Entirely discard the old message cache machinery.

This commit is contained in:
Jon Chambers 2020-08-27 14:24:45 -04:00 committed by Jon Chambers
parent 6061d0603a
commit 18ecd748dd
11 changed files with 163 additions and 742 deletions

View File

@ -68,9 +68,7 @@ directory:
reconciliationChunkIntervalMs: # CDS reconciliation chunk interval, in milliseconds
messageCache: # Redis server configuration for message store cache
redis:
url:
replicaUrls:
persistDelayMinutes:
cluster:
urls:

View File

@ -150,7 +150,6 @@ import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.CertificateCommand;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
import org.whispersystems.textsecuregcm.workers.ScourMessageCacheCommand;
import org.whispersystems.textsecuregcm.workers.VacuumCommand;
import org.whispersystems.textsecuregcm.workers.ZkParamsCommand;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@ -183,7 +182,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new DeleteUserCommand());
bootstrap.addCommand(new CertificateCommand());
bootstrap.addCommand(new ZkParamsCommand());
bootstrap.addCommand(new ScourMessageCacheCommand());
bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("accountdb", "accountsdb.xml") {
@Override

View File

@ -7,11 +7,6 @@ import javax.validation.constraints.NotNull;
public class MessageCacheConfiguration {
@JsonProperty
@NotNull
@Valid
private RedisConfiguration redis;
@JsonProperty
@NotNull
@Valid
@ -20,10 +15,6 @@ public class MessageCacheConfiguration {
@JsonProperty
private int persistDelayMinutes = 10;
public RedisConfiguration getRedisConfiguration() {
return redis;
}
public RedisClusterConfiguration getRedisClusterConfiguration() {
return cluster;
}

View File

@ -1,270 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import redis.clients.jedis.Jedis;
import redis.clients.util.SafeEncoder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class MessagesCache implements UserMessagesCache {
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" ));
private static final Timer removeByIdTimer = metricRegistry.timer(name(MessagesCache.class, "removeById" ));
private static final Timer removeByNameTimer = metricRegistry.timer(name(MessagesCache.class, "removeByName"));
private static final Timer removeByGuidTimer = metricRegistry.timer(name(MessagesCache.class, "removeByGuid"));
private static final Timer getTimer = metricRegistry.timer(name(MessagesCache.class, "get" ));
private static final Timer clearAccountTimer = metricRegistry.timer(name(MessagesCache.class, "clearAccount"));
private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" ));
private final ReplicatedJedisPool jedisPool;
private final InsertOperation insertOperation;
private final RemoveOperation removeOperation;
private final GetOperation getOperation;
public MessagesCache(ReplicatedJedisPool jedisPool) throws IOException {
this.jedisPool = jedisPool;
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
}
@Override
public long insert(UUID guid, String destination, final UUID destinationUuid, long destinationDevice, Envelope message) {
final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
Timer.Context timer = insertTimer.time();
try {
return insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid);
} finally {
timer.stop();
}
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, final UUID destinationUuid, long destinationDevice, long id) {
OutgoingMessageEntity removedMessageEntity = null;
try (Jedis jedis = jedisPool.getWriteResource();
Timer.Context ignored = removeByIdTimer.time())
{
byte[] serialized = removeOperation.remove(jedis, destination, destinationDevice, id);
if (serialized != null) {
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(id, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.ofNullable(removedMessageEntity);
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, final UUID destinationUuid, long destinationDevice, String sender, long timestamp) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByNameTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
if (serialized != null) {
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
} finally {
timer.stop();
}
return Optional.ofNullable(removedMessageEntity);
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, final UUID destinationUuid, long destinationDevice, UUID guid) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByGuidTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, guid);
if (serialized != null) {
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
} finally {
timer.stop();
}
return Optional.ofNullable(removedMessageEntity);
}
@Override
public List<OutgoingMessageEntity> get(String destination, final UUID destinationUuid, long destinationDevice, int limit) {
Timer.Context timer = getTimer.time();
try {
List<OutgoingMessageEntity> results = new LinkedList<>();
Key key = new Key(destination, destinationDevice);
List<Pair<byte[], Double>> items = getOperation.getItems(key.getUserMessageQueue(), key.getUserMessageQueuePersistInProgress(), limit);
for (Pair<byte[], Double> item : items) {
try {
long id = item.second().longValue();
Envelope message = Envelope.parseFrom(item.first());
results.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return results;
} finally {
timer.stop();
}
}
@Override
public void clear(String destination, final UUID destinationUuid) {
Timer.Context timer = clearAccountTimer.time();
try {
for (int i = 1; i < 255; i++) {
clear(destination, destinationUuid, i);
}
} finally {
timer.stop();
}
}
@Override
public void clear(String destination, final UUID destinationUuid, long deviceId) {
Timer.Context timer = clearDeviceTimer.time();
try {
removeOperation.clear(destination, deviceId);
} finally {
timer.stop();
}
}
private static class InsertOperation {
private final LuaScript insert;
InsertOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
}
public long insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
Key key = new Key(destination, destinationDevice);
String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes(), guid.toString().getBytes());
return (long)insert.execute(keys, args);
}
}
private static class RemoveOperation {
private final LuaScript removeById;
private final LuaScript removeBySender;
private final LuaScript removeByGuid;
private final LuaScript removeQueue;
RemoveOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.removeById = LuaScript.fromResource(jedisPool, "lua/remove_item_by_id.lua" );
this.removeBySender = LuaScript.fromResource(jedisPool, "lua/remove_item_by_sender.lua");
this.removeByGuid = LuaScript.fromResource(jedisPool, "lua/remove_item_by_guid.lua" );
this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
}
public byte[] remove(Jedis jedis, String destination, long destinationDevice, long id) {
Key key = new Key(destination, destinationDevice);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(String.valueOf(id).getBytes());
return (byte[])this.removeById.execute(jedis, keys, args);
}
public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {
Key key = new Key(destination, destinationDevice);
String senderKey = sender + "::" + timestamp;
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(senderKey.getBytes());
return (byte[])this.removeBySender.execute(keys, args);
}
public byte[] remove(String destination, long destinationDevice, UUID guid) {
Key key = new Key(destination, destinationDevice);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(guid.toString().getBytes());
return (byte[])this.removeByGuid.execute(keys, args);
}
public void clear(String destination, long deviceId) {
Key key = new Key(destination, deviceId);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = new LinkedList<>();
this.removeQueue.execute(keys, args);
}
}
private static class GetOperation {
private final LuaScript getItems;
GetOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua");
}
List<Pair<byte[], Double>> getItems(byte[] queue, byte[] lock, int limit) {
List<byte[]> keys = Arrays.asList(queue, lock);
List<byte[]> args = Collections.singletonList(String.valueOf(limit).getBytes());
Iterator<byte[]> results = ((List<byte[]>) getItems.execute(keys, args)).iterator();
List<Pair<byte[], Double>> items = new LinkedList<>();
while (results.hasNext()) {
items.add(new Pair<>(results.next(), Double.valueOf(SafeEncoder.encode(results.next()))));
}
return items;
}
}
}

View File

@ -37,7 +37,7 @@ import java.util.concurrent.ExecutorService;
import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String, String> implements UserMessagesCache, Managed {
public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String, String> implements Managed {
private final FaultTolerantRedisCluster redisCluster;
private final FaultTolerantPubSubConnection<String, String> pubSubConnection;
@ -123,7 +123,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
}
}
@Override
public long insert(final UUID guid, final String destination, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
@ -138,7 +137,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
guid.toString().getBytes(StandardCharsets.UTF_8))));
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final long id) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() ->
@ -149,7 +147,7 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized)));
return Optional.of(constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@ -158,7 +156,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final String sender, final long timestamp) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() ->
@ -168,7 +165,7 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
return Optional.of(constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@ -177,7 +174,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
@ -187,7 +183,7 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
List.of(messageGuid.toString().getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
return Optional.of(constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@ -196,7 +192,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
return Optional.empty();
}
@Override
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(final String destination, final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> {
@ -214,7 +209,7 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8));
messageEntities.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
messageEntities.add(constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
@ -257,7 +252,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
});
}
@Override
public void clear(final String destination, final UUID destinationUuid) {
// TODO Remove null check in a fully UUID-based world
if (destinationUuid != null) {
@ -267,7 +261,6 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
}
}
@Override
public void clear(final String destination, final UUID destinationUuid, final long deviceId) {
clearQueueTimer.record(() ->
removeQueueScript.executeBinary(List.of(getMessageQueueKey(destinationUuid, deviceId),
@ -354,6 +347,21 @@ public class RedisClusterMessagesCache extends RedisClusterPubSubAdapter<String,
}
}
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
@VisibleForTesting
static String getQueueName(final UUID accountUuid, final long deviceId) {
return accountUuid + "::" + deviceId;

View File

@ -1,40 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface UserMessagesCache {
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
long insert(UUID guid, String destination, UUID destinationUuid, long destinationDevice, MessageProtos.Envelope message);
Optional<OutgoingMessageEntity> remove(String destination, UUID destinationUuid, long destinationDevice, long id);
Optional<OutgoingMessageEntity> remove(String destination, UUID destinationUuid, long destinationDevice, String sender, long timestamp);
Optional<OutgoingMessageEntity> remove(String destination, UUID destinationUuid, long destinationDevice, UUID guid);
List<OutgoingMessageEntity> get(String destination, UUID destinationUuid, long destinationDevice, int limit);
void clear(String destination, UUID destinationUuid);
void clear(String destination, UUID destinationUuid, long deviceId);
}

View File

@ -1,116 +0,0 @@
package org.whispersystems.textsecuregcm.workers;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.cli.ConfiguredCommand;
import io.dropwizard.setup.Bootstrap;
import io.lettuce.core.ScanArgs;
import io.lettuce.core.ScanIterator;
import io.lettuce.core.ScoredValue;
import net.sourceforge.argparse4j.inf.Namespace;
import org.jdbi.v3.core.Jdbi;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.Messages;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.UUID;
public class ScourMessageCacheCommand extends ConfiguredCommand<WhisperServerConfiguration> {
private FaultTolerantRedisClient redisClient;
private Messages messageDatabase;
private static final int MESSAGE_PAGE_SIZE = 100;
private static final Logger log = LoggerFactory.getLogger(ScourMessageCacheCommand.class);
public ScourMessageCacheCommand() {
super("scourmessagecache", "Persist and remove all message queues from the old message cache");
}
@SuppressWarnings("ConstantConditions")
@Override
protected void run(final Bootstrap<WhisperServerConfiguration> bootstrap, final Namespace namespace, final WhisperServerConfiguration config) {
final DatabaseConfiguration messageDbConfig = config.getMessageStoreConfiguration();
final Jdbi messageJdbi = Jdbi.create(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword());
final FaultTolerantDatabase messageDatabase = new FaultTolerantDatabase("message_database", messageJdbi, config.getMessageStoreConfiguration().getCircuitBreakerConfiguration());
this.setMessageDatabase(new Messages(messageDatabase));
this.setRedisClient(new FaultTolerantRedisClient("scourMessageCacheClient", config.getMessageCacheConfiguration().getRedisConfiguration()));
scourMessageCache();
}
@VisibleForTesting
void setRedisClient(final FaultTolerantRedisClient redisClient) {
this.redisClient = redisClient;
}
@VisibleForTesting
void setMessageDatabase(final Messages messageDatabase) {
this.messageDatabase = messageDatabase;
}
@VisibleForTesting
void scourMessageCache() {
redisClient.useClient(connection -> ScanIterator.scan(connection.sync(), ScanArgs.Builder.matches("user_queue::*"))
.stream()
.forEach(this::persistQueue));
}
@VisibleForTesting
void persistQueue(final String queueKey) {
final String accountNumber;
{
final int startOfAccountNumber = queueKey.indexOf("::");
accountNumber = queueKey.substring(startOfAccountNumber + 2, queueKey.indexOf("::", startOfAccountNumber + 1));
}
final long deviceId = Long.parseLong(queueKey.substring(queueKey.lastIndexOf("::") + 2));
final byte[] queueKeyBytes = queueKey.getBytes(StandardCharsets.UTF_8);
int messageCount = 0;
List<ScoredValue<byte[]>> messages;
do {
final int start = messageCount;
messages = redisClient.withBinaryClient(connection -> connection.sync().zrangeWithScores(queueKeyBytes, start, start + MESSAGE_PAGE_SIZE));
for (final ScoredValue<byte[]> scoredValue : messages) {
persistMessage(accountNumber, deviceId, scoredValue.getValue());
messageCount++;
}
} while (!messages.isEmpty());
redisClient.useClient(connection -> {
final String accountNumberAndDeviceId = accountNumber + "::" + deviceId;
connection.async().del("user_queue::" + accountNumberAndDeviceId,
"user_queue_metadata::" + accountNumberAndDeviceId,
"user_queue_persisting::" + accountNumberAndDeviceId);
});
log.info("Persisted a queue with {} messages", messageCount);
}
private void persistMessage(final String accountNumber, final long deviceId, final byte[] message) {
try {
MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(message);
UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null;
envelope = envelope.toBuilder().clearServerGuid().build();
messageDatabase.store(guid, envelope, accountNumber, deviceId);
} catch (InvalidProtocolBufferException e) {
log.error("Error parsing envelope", e);
}
}
}

View File

@ -1,158 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(JUnitParamsRunner.class)
public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest {
private static final String DESTINATION_ACCOUNT = "+18005551234";
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7;
private final Random random = new Random();
private long serialTimestamp = 0;
protected abstract UserMessagesCache getMessagesCache();
@Test
@Parameters({"true", "false"})
public void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0);
}
@Test
@Parameters({"true", "false"})
public void testRemoveById(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId));
}
@Test
public void testRemoveBySender() {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp());
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()));
}
@Test
@Parameters({"true", "false"})
public void testRemoveByUUID(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid));
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
}
@Test
@Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) {
final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
expectedMessages.add(UserMessagesCache.constructEntityFromEnvelope(messageId, message));
}
assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message);
}
}
getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID);
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message);
}
}
getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_UUID);
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount));
}
protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(serialTimestamp++)
.setServerTimestamp(serialTimestamp++)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(256))
.setSource("+1" + RandomStringUtils.randomNumeric(10));
}
return envelopeBuilder.build();
}
}

View File

@ -1,41 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import org.junit.After;
import org.junit.Before;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import redis.embedded.RedisServer;
import java.util.List;
import static org.mockito.Mockito.mock;
public class MessagesCacheTest extends AbstractMessagesCacheTest {
private RedisServer redisServer;
private MessagesCache messagesCache;
@Before
public void setUp() throws Exception {
redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort());
redisServer.start();
final String redisUrl = String.format("redis://127.0.0.1:%d", redisServer.ports().get(0));
final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration());
final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool();
messagesCache = new MessagesCache(jedisPool);
}
@After
public void tearDown() {
redisServer.stop();
}
@Override
protected UserMessagesCache getMessagesCache() {
return messagesCache;
}
}

View File

@ -1,13 +1,24 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import io.lettuce.core.cluster.SlotHash;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@ -17,15 +28,19 @@ import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
@RunWith(JUnitParamsRunner.class)
public class RedisClusterMessagesCacheTest extends AbstractRedisClusterTest {
private ExecutorService notificationExecutorService;
private RedisClusterMessagesCache messagesCache;
private final Random random = new Random();
private long serialTimestamp = 0;
private static final String DESTINATION_ACCOUNT = "+18005551234";
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7;
private ExecutorService notificationExecutorService;
private RedisClusterMessagesCache messagesCache;
@Override
@Before
public void setUp() throws Exception {
@ -49,9 +64,129 @@ public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
super.tearDown();
}
@Override
protected UserMessagesCache getMessagesCache() {
return messagesCache;
@Test
@Parameters({"true", "false"})
public void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertTrue(messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0);
}
@Test
@Parameters({"true", "false"})
public void testRemoveById(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(RedisClusterMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageId));
}
@Test
public void testRemoveBySender() {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false);
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp());
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(RedisClusterMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()));
}
@Test
@Parameters({"true", "false"})
public void testRemoveByUUID(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid));
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = messagesCache.remove(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(RedisClusterMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
}
@Test
@Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) {
final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
expectedMessages.add(RedisClusterMessagesCache.constructEntityFromEnvelope(messageId, message));
}
assertEquals(expectedMessages, messagesCache.get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message);
}
}
messagesCache.clear(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID);
assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(messageCount, messagesCache.get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_UUID, deviceId, message);
}
}
messagesCache.clear(DESTINATION_ACCOUNT, DESTINATION_UUID);
assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_ACCOUNT, DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount));
}
protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(serialTimestamp++)
.setServerTimestamp(serialTimestamp++)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(256))
.setSource("+1" + RandomStringUtils.randomNumeric(10));
}
return envelopeBuilder.build();
}
@Test

View File

@ -1,84 +0,0 @@
package org.whispersystems.textsecuregcm.workers;
import com.google.protobuf.ByteString;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.AbstractRedisSingletonTest;
import org.whispersystems.textsecuregcm.storage.Messages;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import java.util.Random;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class ScourMessageCacheCommandTest extends AbstractRedisSingletonTest {
private Messages messageDatabase;
private MessagesCache messagesCache;
private ScourMessageCacheCommand scourMessageCacheCommand;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
messageDatabase = mock(Messages.class);
messagesCache = new MessagesCache(getJedisPool());
scourMessageCacheCommand = new ScourMessageCacheCommand();
scourMessageCacheCommand.setMessageDatabase(messageDatabase);
scourMessageCacheCommand.setRedisClient(getRedisClient());
}
@Test
public void testScourMessageCache() {
final int messageCount = insertDetachedMessages(100, 1_000);
scourMessageCacheCommand.scourMessageCache();
verify(messageDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong());
assertEquals(0, (long)getRedisClient().withClient(connection -> connection.sync().dbsize()));
}
@SuppressWarnings("SameParameterValue")
private int insertDetachedMessages(final int accounts, final int maxMessagesPerAccount) {
int totalMessages = 0;
final Random random = new Random();
for (int i = 0; i < accounts; i++) {
final String accountNumber = String.format("+1800%07d", i);
final UUID accountUuid = UUID.randomUUID();
final int messageCount = random.nextInt(maxMessagesPerAccount);
for (int j = 0; j < messageCount; j++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountNumber, accountUuid, 1, envelope);
}
totalMessages += messageCount;
}
getRedisClient().useClient(connection -> connection.sync().del("user_queue_index"));
return totalMessages;
}
}