Add a cluster-capable message persister

This commit is contained in:
Jon Chambers 2020-07-21 11:47:53 -04:00
parent f9f93c77e2
commit beac73b6c8
5 changed files with 500 additions and 27 deletions

View File

@ -0,0 +1,181 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.time.Duration;
import java.time.Instant;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagePersister implements Managed {
private final RedisClusterMessagesCache messagesCache;
private final Messages messagesDatabase;
private final PubSubManager pubSubManager;
private final PushSender pushSender;
private final AccountsManager accountsManager;
private final Duration persistDelay;
private volatile boolean running = false;
private Thread workerThread;
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "getQueues"));
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "persistQueue"));
private static final Timer NOTIFY_SUBSCRIBERS_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "notifySubscribers"));
private static final DistributionSummary QUEUE_COUNT_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueCount"));
private static final DistributionSummary QUEUE_SIZE_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueSize"));
static final int QUEUE_BATCH_LIMIT = 100;
static final int MESSAGE_BATCH_LIMIT = 100;
private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagePersister.class);
public RedisClusterMessagePersister(final RedisClusterMessagesCache messagesCache, final Messages messagesDatabase, final PubSubManager pubSubManager, final PushSender pushSender, final AccountsManager accountsManager, final Duration persistDelay) {
this.messagesCache = messagesCache;
this.messagesDatabase = messagesDatabase;
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.persistDelay = persistDelay;
}
@Override
public void start() {
running = true;
workerThread = new Thread(() -> {
while (running) {
persistNextQueues(Instant.now());
Util.sleep(100);
}
});
workerThread.start();
}
@Override
public void stop() throws Exception {
running = false;
if (workerThread != null) {
workerThread.join();
workerThread = null;
}
}
@VisibleForTesting
void persistNextQueues(final Instant currentTime) {
final int slot = messagesCache.getNextSlotToPersist();
List<String> queuesToPersist;
int queuesPersisted = 0;
do {
queuesToPersist = GET_QUEUES_TIMER.record(() -> messagesCache.getQueuesToPersist(slot, currentTime.minus(persistDelay), QUEUE_BATCH_LIMIT));
for (final String queue : queuesToPersist) {
persistQueue(queue);
notifyClients(RedisClusterMessagesCache.getAccountUuidFromQueueName(queue), RedisClusterMessagesCache.getDeviceIdFromQueueName(queue));
}
queuesPersisted += queuesToPersist.size();
} while (queuesToPersist.size() == QUEUE_BATCH_LIMIT);
QUEUE_COUNT_SUMMARY.record(queuesPersisted);
}
@VisibleForTesting
void persistQueue(final String queue) {
final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queue);
final long deviceId = RedisClusterMessagesCache.getDeviceIdFromQueueName(queue);
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
final String accountNumber;
if (maybeAccount.isPresent()) {
accountNumber = maybeAccount.get().getNumber();
} else {
logger.error("No account record found for account {}", accountUuid);
return;
}
PERSIST_QUEUE_TIMER.record(() -> {
messagesCache.lockQueueForPersistence(queue);
try {
int messageCount = 0;
List<MessageProtos.Envelope> messages;
do {
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
for (final MessageProtos.Envelope message : messages) {
final UUID uuid = UUID.fromString(message.getServerGuid());
messagesDatabase.store(uuid, message, accountNumber, deviceId);
messagesCache.remove(accountNumber, accountUuid, deviceId, uuid);
messageCount++;
}
} while (messages.size() == MESSAGE_BATCH_LIMIT);
QUEUE_SIZE_SUMMARY.record(messageCount);
} finally {
messagesCache.unlockQueueForPersistence(queue);
}
});
}
public void notifyClients(final UUID accountUuid, final long deviceId) {
NOTIFY_SUBSCRIBERS_TIMER.record(() -> {
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
final String address;
if (maybeAccount.isPresent()) {
address = maybeAccount.get().getNumber();
} else {
logger.error("No account record found for account {}", accountUuid);
return;
}
final boolean notified = pubSubManager.publish(new WebsocketAddress(address, deviceId),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build());
if (!notified) {
Optional<Account> account = accountsManager.get(address);
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(deviceId);
if (device.isPresent()) {
try {
pushSender.sendQueuedNotification(account.get(), device.get());
} catch (final NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!");
}
}
}
}
});
}
}

View File

@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -13,6 +15,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@ -23,12 +26,18 @@ import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagesCache implements UserMessagesCache {
private final FaultTolerantRedisCluster redisCluster;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByIdScript;
private final ClusterLuaScript removeBySenderScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private final ClusterLuaScript getQueuesToPersistScript;
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
@ -44,12 +53,15 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
this.redisCluster = redisCluster;
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI);
}
@Override
@ -122,13 +134,13 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID guid) {
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
getQueueIndexKey(destinationUuid, destinationDevice)),
List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
@ -142,11 +154,11 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
@Override
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(String destination, final UUID destinationUuid, long destinationDevice, int limit) {
public List<OutgoingMessageEntity> get(final String destination, final UUID destinationUuid, final long destinationDevice, final int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
getPersistInProgressKey(destinationUuid, destinationDevice)),
List.of(String.valueOf(limit).getBytes()));
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
final List<OutgoingMessageEntity> messageEntities;
@ -172,6 +184,35 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
});
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final long destinationDevice, final int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(accountUuid, destinationDevice),
getPersistInProgressKey(accountUuid, destinationDevice)),
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
final List<MessageProtos.Envelope> envelopes;
if (queueItems.size() % 2 == 0) {
envelopes = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size(); i += 2) {
try {
envelopes.add(MessageProtos.Envelope.parseFrom(queueItems.get(i)));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
} else {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
envelopes = Collections.emptyList();
}
return envelopes;
});
}
@Override
public void clear(final String destination, final UUID destinationUuid) {
// TODO Remove null check in a fully UUID-based world
@ -191,7 +232,27 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
Collections.emptyList()));
}
private static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
int getNextSlotToPersist() {
return (int)(redisCluster.withWriteCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT);
}
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
//noinspection unchecked
return (List<String>)getQueuesToPersistScript.execute(List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
List.of(String.valueOf(maxTime.toEpochMilli()),
String.valueOf(limit)));
}
void lockQueueForPersistence(final String queue) {
redisCluster.useBinaryWriteCluster(connection -> connection.sync().setex(getPersistInProgressKey(queue), 30, LOCK_VALUE));
}
void unlockQueueForPersistence(final String queue) {
redisCluster.useBinaryWriteCluster(connection -> connection.sync().del(getPersistInProgressKey(queue)));
}
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
@ -199,11 +260,29 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(accountUuid.toString() + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
}
private byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
return ("user_queue_persisting::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
private static byte[] getQueueIndexKey(final int slot) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
return getPersistInProgressKey(accountUuid + "::" + deviceId);
}
private static byte[] getPersistInProgressKey(final String queueName) {
return ("user_queue_persisting::{" + queueName + "}").getBytes(StandardCharsets.UTF_8);
}
static UUID getAccountUuidFromQueueName(final String queueName) {
final int startOfHashTag = queueName.indexOf('{');
return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag)));
}
static long getDeviceIdFromQueueName(final String queueName) {
return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
}
}

View File

@ -1,8 +1,6 @@
package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
public class RedisClusterUtil {
@ -23,15 +21,7 @@ public class RedisClusterUtil {
}
}
/**
* Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
*
* @param key the key for which to find a matching hash tag
* @return a Redis hash tag that maps to the same Redis cluster slot as the given key
*
* @see <a href="https://redis.io/topics/cluster-spec#keys-hash-tags">Redis Cluster Specification - Keys hash tags</a>
*/
public static String getMinimalHashTag(final String key) {
return HASHES_BY_SLOT[SlotHash.getSlot(key)];
public static String getMinimalHashTag(final int slot) {
return HASHES_BY_SLOT[slot];
}
}

View File

@ -0,0 +1,189 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PushSender;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class RedisClusterMessagePersisterTest {
private RedisClusterMessagesCache messagesCache;
private Messages messagesDatabase;
private RedisClusterMessagePersister messagePersister;
private AccountsManager accountsManager;
private long serialTimestamp = 0;
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
private static final Random RANDOM = new Random();
@Before
public void setUp() {
messagesCache = mock(RedisClusterMessagesCache.class);
messagesDatabase = mock(Messages.class);
accountsManager = mock(AccountsManager.class);
final Account account = mock(Account.class);
when(accountsManager.get(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
messagePersister = new RedisClusterMessagePersister(messagesCache, messagesDatabase, mock(PubSubManager.class), mock(PushSender.class), accountsManager, PERSIST_DELAY);
}
@Test
public void testPersistNextQueuesNoQueues() {
final int slot = 7;
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(Collections.emptyList());
messagePersister.persistNextQueues(Instant.now());
verify(messagesCache, never()).lockQueueForPersistence(any());
}
@Test
public void testPersistNextQueuesSingleQueue() {
final int slot = 7;
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(List.of(queueName));
messagePersister.persistNextQueues(Instant.now());
verify(messagesCache).lockQueueForPersistence(queueName);
}
@Test
public void testPersistNextQueuesMultiplePages() {
final int slot = 7;
final int queueCount = RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3;
final List<String> queues = new ArrayList<>(queueCount);
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueName();
final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queueName);
queues.add(queueName);
final Account account = mock(Account.class);
when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn("+1" + RandomStringUtils.randomNumeric(10));
}
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT)))
.thenReturn(queues.subList(0, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))
.thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2))
.thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3))
.thenReturn(Collections.emptyList());
messagePersister.persistNextQueues(Instant.now());
verify(messagesCache, times(queueCount)).lockQueueForPersistence(any());
}
@Test
public void testPersistQueueNoMessages() {
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(Collections.emptyList());
messagePersister.persistQueue(queueName);
verify(messagesCache).lockQueueForPersistence(queueName);
verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
verify(messagesDatabase, never()).store(any(), any(), any(), anyLong());
verify(messagesCache, never()).remove(anyString(), any(UUID.class), anyLong(), any(UUID.class));
verify(messagesCache).unlockQueueForPersistence(queueName);
}
@Test
public void testPersistQueueSingleMessage() {
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final MessageProtos.Envelope message = generateRandomMessage();
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(List.of(message));
messagePersister.persistQueue(queueName);
verify(messagesCache).lockQueueForPersistence(queueName);
verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
verify(messagesDatabase).store(UUID.fromString(message.getServerGuid()), message, DESTINATION_ACCOUNT_NUMBER, DESTINATION_DEVICE_ID);
verify(messagesCache).remove(DESTINATION_ACCOUNT_NUMBER, DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, UUID.fromString(message.getServerGuid()));
verify(messagesCache).unlockQueueForPersistence(queueName);
}
@Test
public void testPersistQueueMultiplePages() {
final int messageCount = RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3;
final List<MessageProtos.Envelope> messagesInQueue = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messagesInQueue.add(generateRandomMessage());
}
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT))
.thenReturn(messagesInQueue.subList(0, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT))
.thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2))
.thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3))
.thenReturn(Collections.emptyList());
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
messagePersister.persistQueue(queueName);
verify(messagesCache).lockQueueForPersistence(queueName);
verify(messagesCache, times(4)).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
verify(messagesDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID));
verify(messagesCache, times(messageCount)).remove(eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID), any(UUID.class));
verify(messagesCache).unlockQueueForPersistence(queueName);
}
private MessageProtos.Envelope generateRandomMessage() {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(serialTimestamp++)
.setServerTimestamp(serialTimestamp++)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(UUID.randomUUID().toString());
return envelopeBuilder.build();
}
private String generateRandomQueueName() {
return String.format("user_queue::{%s::%d}", UUID.randomUUID().toString(), RANDOM.nextInt(10));
}
}

View File

@ -1,13 +1,18 @@
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.cluster.SlotHash;
import junitparams.Parameters;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.List;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
@ -52,4 +57,33 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
// We're happy as long as this doesn't throw an exception
messagesCache.clear(DESTINATION_ACCOUNT, null);
}
@Test
public void testGetAccountFromQueueName() {
assertEquals(DESTINATION_UUID,
RedisClusterMessagesCache.getAccountUuidFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8)));
}
@Test
public void testGetDeviceIdFromQueueName() {
assertEquals(DESTINATION_DEVICE_ID,
RedisClusterMessagesCache.getDeviceIdFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8)));
}
@Test
@Parameters({"true", "false"})
public void testGetQueuesToPersist(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender));
final int slot = SlotHash.getSlot(DESTINATION_UUID.toString() + "::" + DESTINATION_DEVICE_ID);
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
final List<String> queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100);
assertEquals(1, queues.size());
assertEquals(DESTINATION_UUID, RedisClusterMessagesCache.getAccountUuidFromQueueName(queues.get(0)));
assertEquals(DESTINATION_DEVICE_ID, RedisClusterMessagesCache.getDeviceIdFromQueueName(queues.get(0)));
}
}