Add a cluster-capable message persister
This commit is contained in:
parent
f9f93c77e2
commit
beac73b6c8
|
@ -0,0 +1,181 @@
|
|||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.dropwizard.lifecycle.Managed;
|
||||
import io.micrometer.core.instrument.DistributionSummary;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.Timer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||
import org.whispersystems.textsecuregcm.push.PushSender;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
|
||||
public class RedisClusterMessagePersister implements Managed {
|
||||
|
||||
private final RedisClusterMessagesCache messagesCache;
|
||||
private final Messages messagesDatabase;
|
||||
private final PubSubManager pubSubManager;
|
||||
private final PushSender pushSender;
|
||||
private final AccountsManager accountsManager;
|
||||
|
||||
private final Duration persistDelay;
|
||||
|
||||
private volatile boolean running = false;
|
||||
private Thread workerThread;
|
||||
|
||||
private static final Timer GET_QUEUES_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "getQueues"));
|
||||
private static final Timer PERSIST_QUEUE_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "persistQueue"));
|
||||
private static final Timer NOTIFY_SUBSCRIBERS_TIMER = Metrics.timer(name(RedisClusterMessagePersister.class, "notifySubscribers"));
|
||||
private static final DistributionSummary QUEUE_COUNT_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueCount"));
|
||||
private static final DistributionSummary QUEUE_SIZE_SUMMARY = Metrics.summary(name(RedisClusterMessagePersister.class, "queueSize"));
|
||||
|
||||
static final int QUEUE_BATCH_LIMIT = 100;
|
||||
static final int MESSAGE_BATCH_LIMIT = 100;
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagePersister.class);
|
||||
|
||||
public RedisClusterMessagePersister(final RedisClusterMessagesCache messagesCache, final Messages messagesDatabase, final PubSubManager pubSubManager, final PushSender pushSender, final AccountsManager accountsManager, final Duration persistDelay) {
|
||||
this.messagesCache = messagesCache;
|
||||
this.messagesDatabase = messagesDatabase;
|
||||
this.pubSubManager = pubSubManager;
|
||||
this.pushSender = pushSender;
|
||||
this.accountsManager = accountsManager;
|
||||
|
||||
this.persistDelay = persistDelay;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
running = true;
|
||||
|
||||
workerThread = new Thread(() -> {
|
||||
while (running) {
|
||||
persistNextQueues(Instant.now());
|
||||
Util.sleep(100);
|
||||
}
|
||||
});
|
||||
|
||||
workerThread.start();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop() throws Exception {
|
||||
running = false;
|
||||
|
||||
if (workerThread != null) {
|
||||
workerThread.join();
|
||||
workerThread = null;
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void persistNextQueues(final Instant currentTime) {
|
||||
final int slot = messagesCache.getNextSlotToPersist();
|
||||
|
||||
List<String> queuesToPersist;
|
||||
int queuesPersisted = 0;
|
||||
|
||||
do {
|
||||
queuesToPersist = GET_QUEUES_TIMER.record(() -> messagesCache.getQueuesToPersist(slot, currentTime.minus(persistDelay), QUEUE_BATCH_LIMIT));
|
||||
|
||||
for (final String queue : queuesToPersist) {
|
||||
persistQueue(queue);
|
||||
notifyClients(RedisClusterMessagesCache.getAccountUuidFromQueueName(queue), RedisClusterMessagesCache.getDeviceIdFromQueueName(queue));
|
||||
}
|
||||
|
||||
queuesPersisted += queuesToPersist.size();
|
||||
} while (queuesToPersist.size() == QUEUE_BATCH_LIMIT);
|
||||
|
||||
QUEUE_COUNT_SUMMARY.record(queuesPersisted);
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void persistQueue(final String queue) {
|
||||
final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queue);
|
||||
final long deviceId = RedisClusterMessagesCache.getDeviceIdFromQueueName(queue);
|
||||
|
||||
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
|
||||
|
||||
final String accountNumber;
|
||||
|
||||
if (maybeAccount.isPresent()) {
|
||||
accountNumber = maybeAccount.get().getNumber();
|
||||
} else {
|
||||
logger.error("No account record found for account {}", accountUuid);
|
||||
return;
|
||||
}
|
||||
|
||||
PERSIST_QUEUE_TIMER.record(() -> {
|
||||
messagesCache.lockQueueForPersistence(queue);
|
||||
|
||||
try {
|
||||
int messageCount = 0;
|
||||
List<MessageProtos.Envelope> messages;
|
||||
|
||||
do {
|
||||
messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT);
|
||||
|
||||
for (final MessageProtos.Envelope message : messages) {
|
||||
final UUID uuid = UUID.fromString(message.getServerGuid());
|
||||
|
||||
messagesDatabase.store(uuid, message, accountNumber, deviceId);
|
||||
messagesCache.remove(accountNumber, accountUuid, deviceId, uuid);
|
||||
|
||||
messageCount++;
|
||||
}
|
||||
} while (messages.size() == MESSAGE_BATCH_LIMIT);
|
||||
|
||||
QUEUE_SIZE_SUMMARY.record(messageCount);
|
||||
} finally {
|
||||
messagesCache.unlockQueueForPersistence(queue);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public void notifyClients(final UUID accountUuid, final long deviceId) {
|
||||
NOTIFY_SUBSCRIBERS_TIMER.record(() -> {
|
||||
final Optional<Account> maybeAccount = accountsManager.get(accountUuid);
|
||||
|
||||
final String address;
|
||||
|
||||
if (maybeAccount.isPresent()) {
|
||||
address = maybeAccount.get().getNumber();
|
||||
} else {
|
||||
logger.error("No account record found for account {}", accountUuid);
|
||||
return;
|
||||
}
|
||||
|
||||
final boolean notified = pubSubManager.publish(new WebsocketAddress(address, deviceId),
|
||||
PubSubProtos.PubSubMessage.newBuilder()
|
||||
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
|
||||
.build());
|
||||
|
||||
if (!notified) {
|
||||
Optional<Account> account = accountsManager.get(address);
|
||||
|
||||
if (account.isPresent()) {
|
||||
Optional<Device> device = account.get().getDevice(deviceId);
|
||||
|
||||
if (device.isPresent()) {
|
||||
try {
|
||||
pushSender.sendQueuedNotification(account.get(), device.get());
|
||||
} catch (final NotPushRegisteredException e) {
|
||||
logger.warn("After message persistence, no longer push registered!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -1,7 +1,9 @@
|
|||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import io.lettuce.core.cluster.SlotHash;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -13,6 +15,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
@ -23,12 +26,18 @@ import static com.codahale.metrics.MetricRegistry.name;
|
|||
|
||||
public class RedisClusterMessagesCache implements UserMessagesCache {
|
||||
|
||||
private final FaultTolerantRedisCluster redisCluster;
|
||||
|
||||
private final ClusterLuaScript insertScript;
|
||||
private final ClusterLuaScript removeByIdScript;
|
||||
private final ClusterLuaScript removeBySenderScript;
|
||||
private final ClusterLuaScript removeByGuidScript;
|
||||
private final ClusterLuaScript getItemsScript;
|
||||
private final ClusterLuaScript removeQueueScript;
|
||||
private final ClusterLuaScript getQueuesToPersistScript;
|
||||
|
||||
static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot";
|
||||
private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
|
||||
private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
|
||||
|
@ -44,12 +53,15 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
|
||||
public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
|
||||
|
||||
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
|
||||
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
|
||||
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
|
||||
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
|
||||
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
|
||||
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
|
||||
this.redisCluster = redisCluster;
|
||||
|
||||
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
|
||||
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
|
||||
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
|
||||
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
|
||||
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
|
||||
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
|
||||
this.getQueuesToPersistScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_queues_to_persist.lua", ScriptOutputType.MULTI);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -122,13 +134,13 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID guid) {
|
||||
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) {
|
||||
try {
|
||||
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
|
||||
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
|
||||
getMessageQueueMetadataKey(destinationUuid, destinationDevice),
|
||||
getQueueIndexKey(destinationUuid, destinationDevice)),
|
||||
List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
|
||||
List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8))));
|
||||
|
||||
if (serialized != null) {
|
||||
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
|
||||
|
@ -142,11 +154,11 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public List<OutgoingMessageEntity> get(String destination, final UUID destinationUuid, long destinationDevice, int limit) {
|
||||
public List<OutgoingMessageEntity> get(final String destination, final UUID destinationUuid, final long destinationDevice, final int limit) {
|
||||
return Metrics.timer(GET_TIMER_NAME).record(() -> {
|
||||
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
|
||||
getPersistInProgressKey(destinationUuid, destinationDevice)),
|
||||
List.of(String.valueOf(limit).getBytes()));
|
||||
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final List<OutgoingMessageEntity> messageEntities;
|
||||
|
||||
|
@ -172,6 +184,35 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
});
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@VisibleForTesting
|
||||
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final long destinationDevice, final int limit) {
|
||||
return Metrics.timer(GET_TIMER_NAME).record(() -> {
|
||||
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(accountUuid, destinationDevice),
|
||||
getPersistInProgressKey(accountUuid, destinationDevice)),
|
||||
List.of(String.valueOf(limit).getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final List<MessageProtos.Envelope> envelopes;
|
||||
|
||||
if (queueItems.size() % 2 == 0) {
|
||||
envelopes = new ArrayList<>(queueItems.size() / 2);
|
||||
|
||||
for (int i = 0; i < queueItems.size(); i += 2) {
|
||||
try {
|
||||
envelopes.add(MessageProtos.Envelope.parseFrom(queueItems.get(i)));
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
logger.warn("Failed to parse envelope", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
|
||||
envelopes = Collections.emptyList();
|
||||
}
|
||||
|
||||
return envelopes;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String destination, final UUID destinationUuid) {
|
||||
// TODO Remove null check in a fully UUID-based world
|
||||
|
@ -191,7 +232,27 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
Collections.emptyList()));
|
||||
}
|
||||
|
||||
private static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
|
||||
int getNextSlotToPersist() {
|
||||
return (int)(redisCluster.withWriteCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT);
|
||||
}
|
||||
|
||||
List<String> getQueuesToPersist(final int slot, final Instant maxTime, final int limit) {
|
||||
//noinspection unchecked
|
||||
return (List<String>)getQueuesToPersistScript.execute(List.of(new String(getQueueIndexKey(slot), StandardCharsets.UTF_8)),
|
||||
List.of(String.valueOf(maxTime.toEpochMilli()),
|
||||
String.valueOf(limit)));
|
||||
}
|
||||
|
||||
void lockQueueForPersistence(final String queue) {
|
||||
redisCluster.useBinaryWriteCluster(connection -> connection.sync().setex(getPersistInProgressKey(queue), 30, LOCK_VALUE));
|
||||
}
|
||||
|
||||
void unlockQueueForPersistence(final String queue) {
|
||||
redisCluster.useBinaryWriteCluster(connection -> connection.sync().del(getPersistInProgressKey(queue)));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
|
||||
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
|
@ -199,11 +260,29 @@ public class RedisClusterMessagesCache implements UserMessagesCache {
|
|||
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
private byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
|
||||
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(accountUuid.toString() + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
|
||||
private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
|
||||
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
|
||||
}
|
||||
|
||||
private byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
|
||||
return ("user_queue_persisting::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
|
||||
private static byte[] getQueueIndexKey(final int slot) {
|
||||
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
|
||||
return getPersistInProgressKey(accountUuid + "::" + deviceId);
|
||||
}
|
||||
|
||||
private static byte[] getPersistInProgressKey(final String queueName) {
|
||||
return ("user_queue_persisting::{" + queueName + "}").getBytes(StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
static UUID getAccountUuidFromQueueName(final String queueName) {
|
||||
final int startOfHashTag = queueName.indexOf('{');
|
||||
|
||||
return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag)));
|
||||
}
|
||||
|
||||
static long getDeviceIdFromQueueName(final String queueName) {
|
||||
return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import io.lettuce.core.cluster.SlotHash;
|
||||
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
|
||||
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
|
||||
|
||||
public class RedisClusterUtil {
|
||||
|
||||
|
@ -23,15 +21,7 @@ public class RedisClusterUtil {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
|
||||
*
|
||||
* @param key the key for which to find a matching hash tag
|
||||
* @return a Redis hash tag that maps to the same Redis cluster slot as the given key
|
||||
*
|
||||
* @see <a href="https://redis.io/topics/cluster-spec#keys-hash-tags">Redis Cluster Specification - Keys hash tags</a>
|
||||
*/
|
||||
public static String getMinimalHashTag(final String key) {
|
||||
return HASHES_BY_SLOT[SlotHash.getSlot(key)];
|
||||
public static String getMinimalHashTag(final int slot) {
|
||||
return HASHES_BY_SLOT[slot];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import com.google.protobuf.ByteString;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.push.PushSender;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class RedisClusterMessagePersisterTest {
|
||||
|
||||
private RedisClusterMessagesCache messagesCache;
|
||||
private Messages messagesDatabase;
|
||||
private RedisClusterMessagePersister messagePersister;
|
||||
private AccountsManager accountsManager;
|
||||
|
||||
private long serialTimestamp = 0;
|
||||
|
||||
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
|
||||
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
|
||||
private static final long DESTINATION_DEVICE_ID = 7;
|
||||
|
||||
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
|
||||
|
||||
private static final Random RANDOM = new Random();
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
messagesCache = mock(RedisClusterMessagesCache.class);
|
||||
messagesDatabase = mock(Messages.class);
|
||||
accountsManager = mock(AccountsManager.class);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
when(accountsManager.get(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
|
||||
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
|
||||
|
||||
messagePersister = new RedisClusterMessagePersister(messagesCache, messagesDatabase, mock(PubSubManager.class), mock(PushSender.class), accountsManager, PERSIST_DELAY);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistNextQueuesNoQueues() {
|
||||
final int slot = 7;
|
||||
|
||||
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
|
||||
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(Collections.emptyList());
|
||||
|
||||
messagePersister.persistNextQueues(Instant.now());
|
||||
|
||||
verify(messagesCache, never()).lockQueueForPersistence(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistNextQueuesSingleQueue() {
|
||||
final int slot = 7;
|
||||
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||
|
||||
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
|
||||
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))).thenReturn(List.of(queueName));
|
||||
|
||||
messagePersister.persistNextQueues(Instant.now());
|
||||
|
||||
verify(messagesCache).lockQueueForPersistence(queueName);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistNextQueuesMultiplePages() {
|
||||
final int slot = 7;
|
||||
final int queueCount = RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3;
|
||||
|
||||
final List<String> queues = new ArrayList<>(queueCount);
|
||||
|
||||
for (int i = 0; i < queueCount; i++) {
|
||||
final String queueName = generateRandomQueueName();
|
||||
final UUID accountUuid = RedisClusterMessagesCache.getAccountUuidFromQueueName(queueName);
|
||||
|
||||
queues.add(queueName);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account));
|
||||
when(account.getNumber()).thenReturn("+1" + RandomStringUtils.randomNumeric(10));
|
||||
}
|
||||
|
||||
when(messagesCache.getNextSlotToPersist()).thenReturn(slot);
|
||||
when(messagesCache.getQueuesToPersist(eq(slot), any(Instant.class), eq(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT)))
|
||||
.thenReturn(queues.subList(0, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT))
|
||||
.thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2))
|
||||
.thenReturn(queues.subList(RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 2, RedisClusterMessagePersister.QUEUE_BATCH_LIMIT * 3))
|
||||
.thenReturn(Collections.emptyList());
|
||||
|
||||
messagePersister.persistNextQueues(Instant.now());
|
||||
|
||||
verify(messagesCache, times(queueCount)).lockQueueForPersistence(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistQueueNoMessages() {
|
||||
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||
|
||||
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(Collections.emptyList());
|
||||
|
||||
messagePersister.persistQueue(queueName);
|
||||
|
||||
verify(messagesCache).lockQueueForPersistence(queueName);
|
||||
verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
|
||||
verify(messagesDatabase, never()).store(any(), any(), any(), anyLong());
|
||||
verify(messagesCache, never()).remove(anyString(), any(UUID.class), anyLong(), any(UUID.class));
|
||||
verify(messagesCache).unlockQueueForPersistence(queueName);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistQueueSingleMessage() {
|
||||
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||
|
||||
final MessageProtos.Envelope message = generateRandomMessage();
|
||||
|
||||
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT)).thenReturn(List.of(message));
|
||||
|
||||
messagePersister.persistQueue(queueName);
|
||||
|
||||
verify(messagesCache).lockQueueForPersistence(queueName);
|
||||
verify(messagesCache).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
|
||||
verify(messagesDatabase).store(UUID.fromString(message.getServerGuid()), message, DESTINATION_ACCOUNT_NUMBER, DESTINATION_DEVICE_ID);
|
||||
verify(messagesCache).remove(DESTINATION_ACCOUNT_NUMBER, DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, UUID.fromString(message.getServerGuid()));
|
||||
verify(messagesCache).unlockQueueForPersistence(queueName);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPersistQueueMultiplePages() {
|
||||
final int messageCount = RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3;
|
||||
final List<MessageProtos.Envelope> messagesInQueue = new ArrayList<>(messageCount);
|
||||
|
||||
for (int i = 0; i < messageCount; i++) {
|
||||
messagesInQueue.add(generateRandomMessage());
|
||||
}
|
||||
|
||||
when(messagesCache.getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT))
|
||||
.thenReturn(messagesInQueue.subList(0, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT))
|
||||
.thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2))
|
||||
.thenReturn(messagesInQueue.subList(RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 2, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT * 3))
|
||||
.thenReturn(Collections.emptyList());
|
||||
|
||||
final String queueName = new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||
|
||||
messagePersister.persistQueue(queueName);
|
||||
|
||||
verify(messagesCache).lockQueueForPersistence(queueName);
|
||||
verify(messagesCache, times(4)).getMessagesToPersist(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, RedisClusterMessagePersister.MESSAGE_BATCH_LIMIT);
|
||||
verify(messagesDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID));
|
||||
verify(messagesCache, times(messageCount)).remove(eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID), any(UUID.class));
|
||||
verify(messagesCache).unlockQueueForPersistence(queueName);
|
||||
}
|
||||
|
||||
private MessageProtos.Envelope generateRandomMessage() {
|
||||
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
|
||||
.setTimestamp(serialTimestamp++)
|
||||
.setServerTimestamp(serialTimestamp++)
|
||||
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
|
||||
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
|
||||
.setServerGuid(UUID.randomUUID().toString());
|
||||
|
||||
return envelopeBuilder.build();
|
||||
}
|
||||
|
||||
private String generateRandomQueueName() {
|
||||
return String.format("user_queue::{%s::%d}", UUID.randomUUID().toString(), RANDOM.nextInt(10));
|
||||
}
|
||||
}
|
|
@ -1,13 +1,18 @@
|
|||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import io.lettuce.core.cluster.SlotHash;
|
||||
import junitparams.Parameters;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
|
||||
|
||||
|
@ -52,4 +57,33 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
|
|||
// We're happy as long as this doesn't throw an exception
|
||||
messagesCache.clear(DESTINATION_ACCOUNT, null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetAccountFromQueueName() {
|
||||
assertEquals(DESTINATION_UUID,
|
||||
RedisClusterMessagesCache.getAccountUuidFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8)));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetDeviceIdFromQueueName() {
|
||||
assertEquals(DESTINATION_DEVICE_ID,
|
||||
RedisClusterMessagesCache.getDeviceIdFromQueueName(new String(RedisClusterMessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8)));
|
||||
}
|
||||
|
||||
@Test
|
||||
@Parameters({"true", "false"})
|
||||
public void testGetQueuesToPersist(final boolean sealedSender) {
|
||||
final UUID messageGuid = UUID.randomUUID();
|
||||
|
||||
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender));
|
||||
final int slot = SlotHash.getSlot(DESTINATION_UUID.toString() + "::" + DESTINATION_DEVICE_ID);
|
||||
|
||||
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
|
||||
|
||||
final List<String> queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100);
|
||||
|
||||
assertEquals(1, queues.size());
|
||||
assertEquals(DESTINATION_UUID, RedisClusterMessagesCache.getAccountUuidFromQueueName(queues.get(0)));
|
||||
assertEquals(DESTINATION_DEVICE_ID, RedisClusterMessagesCache.getDeviceIdFromQueueName(queues.get(0)));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue