Factor MessagePersister into its own class.

This commit is contained in:
Jon Chambers 2020-08-06 17:38:06 -04:00 committed by Jon Chambers
parent e35e34d2e0
commit 5fad8f74b1
5 changed files with 280 additions and 280 deletions

View File

@ -128,6 +128,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryReconciler;
import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.MessagePersister;
import org.whispersystems.textsecuregcm.storage.Messages;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -344,7 +345,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
UsernamesManager usernamesManager = new UsernamesManager(usernames, reservedUsernames, cacheCluster);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
RedisClusterMessagesCache clusterMessagesCache = new RedisClusterMessagesCache(messagesCacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes());
MessagesCache messagesCache = new MessagesCache(messagesClient);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster);
MessagesManager messagesManager = new MessagesManager(messages, messagesCache, clusterMessagesCache, pushLatencyManager, messageCacheClusterExperimentExecutor);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
@ -374,6 +375,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(config.getTurnConfiguration());
RecaptchaClient recaptchaClient = new RecaptchaClient(config.getRecaptchaConfiguration().getSecret());
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesClient, messages, pubSubManager, pushSender, accountsManager,config.getMessageCacheConfiguration().getPersistDelayMinutes(), TimeUnit.MINUTES);
RedisClusterMessagePersister clusterMessagePersister = new RedisClusterMessagePersister(clusterMessagesCache, messages, pubSubManager, pushSender, accountsManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
DirectoryReconciliationClient directoryReconciliationClient = new DirectoryReconciliationClient(config.getDirectoryConfiguration().getDirectoryServerConfiguration());
@ -388,13 +390,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountDatabaseCrawlerCache accountDatabaseCrawlerCache = new AccountDatabaseCrawlerCache(cacheCluster);
AccountDatabaseCrawler accountDatabaseCrawler = new AccountDatabaseCrawler(accountsManager, accountDatabaseCrawlerCache, accountDatabaseCrawlerListeners, config.getAccountDatabaseCrawlerConfiguration().getChunkSize(), config.getAccountDatabaseCrawlerConfiguration().getChunkIntervalMs());
messagesCache.setPubSubManager(pubSubManager, pushSender);
apnSender.setApnFallbackManager(apnFallbackManager);
environment.lifecycle().manage(apnFallbackManager);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(pushSender);
environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(messagePersister);
environment.lifecycle().manage(accountDatabaseCrawler);
environment.lifecycle().manage(remoteConfigsManager);
environment.lifecycle().manage(clusterMessagePersister);

View File

@ -0,0 +1,59 @@
package org.whispersystems.textsecuregcm.storage;
import java.io.IOException;
class Key {
private final byte[] userMessageQueue;
private final byte[] userMessageQueueMetadata;
private final byte[] userMessageQueuePersistInProgress;
private final String address;
private final long deviceId;
Key(String address, long deviceId) {
this.address = address;
this.deviceId = deviceId;
this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes();
this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes();
this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes();
}
String getAddress() {
return address;
}
long getDeviceId() {
return deviceId;
}
byte[] getUserMessageQueue() {
return userMessageQueue;
}
byte[] getUserMessageQueueMetadata() {
return userMessageQueueMetadata;
}
byte[] getUserMessageQueuePersistInProgress() {
return userMessageQueuePersistInProgress;
}
static byte[] getUserMessageQueueIndex() {
return "user_queue_index".getBytes();
}
static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException {
try {
String[] parts = new String(userMessageQueue).split("::");
if (parts.length != 3) {
throw new IOException("Malformed key: " + new String(userMessageQueue));
}
return new Key(parts[1], Long.parseLong(parts[2]));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
}

View File

@ -0,0 +1,214 @@
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
public class MessagePersister implements Managed, Runnable {
private final MessagesCache messagesCache;
private final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" ));
private final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue"));
private final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" ));
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" ));
private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount"));
private static final int CHUNK_SIZE = 100;
private final AtomicBoolean running = new AtomicBoolean(true);
private final ReplicatedJedisPool jedisPool;
private final Messages database;
private final long delayTime;
private final TimeUnit delayTimeUnit;
private final PubSubManager pubSubManager;
private final PushSender pushSender;
private final AccountsManager accountsManager;
private final LuaScript getQueuesScript;
private boolean finished = false;
public MessagePersister(final MessagesCache messagesCache,
final ReplicatedJedisPool jedisPool,
final Messages database,
final PubSubManager pubSubManager,
final PushSender pushSender,
final AccountsManager accountsManager,
final long delayTime,
final TimeUnit delayTimeUnit)
throws IOException
{
this.messagesCache = messagesCache;
this.jedisPool = jedisPool;
this.database = database;
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.delayTime = delayTime;
this.delayTimeUnit = delayTimeUnit;
this.getQueuesScript = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua");
}
@Override
public void start() {
new Thread(this, getClass().getSimpleName()).start();
}
@Override
public void run() {
while (running.get()) {
try {
List<byte[]> queuesToPersist = getQueuesToPersist();
queueCountHistogram.update(queuesToPersist.size());
for (byte[] queue : queuesToPersist) {
Key key = Key.fromUserMessageQueue(queue);
persistQueue(jedisPool, key);
notifyClients(accountsManager, pubSubManager, pushSender, key);
}
if (queuesToPersist.isEmpty()) {
//noinspection BusyWait
Thread.sleep(10_000);
}
} catch (Throwable t) {
logger.error("Exception while persisting: ", t);
}
}
synchronized (this) {
finished = true;
notifyAll();
}
}
@Override
public synchronized void stop() {
running.set(false);
while (!finished) Util.wait(this);
logger.info("Message persister shut down...");
}
private void persistQueue(ReplicatedJedisPool jedisPool, Key key) {
Timer.Context timer = persistQueueTimer.time();
int messagesPersistedCount = 0;
UUID destinationUuid = accountsManager.get(key.getAddress()).map(Account::getUuid).orElse(null);
try (Jedis jedis = jedisPool.getWriteResource()) {
while (true) {
jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes());
Set<Tuple> messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE);
for (Tuple message : messages) {
persistMessage(key, destinationUuid, (long)message.getScore(), message.getBinaryElement());
messagesPersistedCount++;
}
if (messages.size() < CHUNK_SIZE) {
jedis.del(key.getUserMessageQueuePersistInProgress());
return;
}
}
} finally {
timer.stop();
queueSizeHistogram.update(messagesPersistedCount);
}
}
private void persistMessage(Key key, UUID destinationUuid, long score, byte[] message) {
try {
MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(message);
UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null;
envelope = envelope.toBuilder().clearServerGuid().build();
database.store(guid, envelope, key.getAddress(), key.getDeviceId());
} catch (InvalidProtocolBufferException e) {
logger.error("Error parsing envelope", e);
}
messagesCache.remove(key.getAddress(), destinationUuid, key.getDeviceId(), score);
}
private List<byte[]> getQueuesToPersist() {
Timer.Context timer = getQueuesTimer.time();
try {
long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime);
List<byte[]> keys = Collections.singletonList(Key.getUserMessageQueueIndex());
List<byte[]> args = Arrays.asList(String.valueOf(maxTime).getBytes(), String.valueOf(100).getBytes());
//noinspection unchecked
return (List<byte[]>)getQueuesScript.execute(keys, args);
} finally {
timer.stop();
}
}
private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key) {
Timer.Context timer = notifyTimer.time();
try {
boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build());
if (!notified) {
Optional<Account> account = accountsManager.get(key.getAddress());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(key.getDeviceId());
if (device.isPresent()) {
try {
pushSender.sendQueuedNotification(account.get(), device.get());
} catch (NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!");
}
}
}
}
} finally {
timer.stop();
}
}
}

View File

@ -1,26 +1,18 @@
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
import java.io.IOException;
@ -30,17 +22,11 @@ import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
public class MessagesCache implements Managed, UserMessagesCache {
public class MessagesCache implements UserMessagesCache {
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
@ -54,23 +40,13 @@ public class MessagesCache implements Managed, UserMessagesCache {
private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" ));
private final ReplicatedJedisPool jedisPool;
private final Messages database;
private final AccountsManager accountsManager;
private final int delayMinutes;
private final InsertOperation insertOperation;
private final RemoveOperation removeOperation;
private final GetOperation getOperation;
private PubSubManager pubSubManager;
private PushSender pushSender;
private MessagePersister messagePersister;
public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) throws IOException {
public MessagesCache(ReplicatedJedisPool jedisPool) throws IOException {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
@ -198,79 +174,6 @@ public class MessagesCache implements Managed, UserMessagesCache {
}
}
public void setPubSubManager(PubSubManager pubSubManager, PushSender pushSender) {
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
}
@Override
public void start() throws Exception {
this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
this.messagePersister.start();
}
@Override
public void stop() throws Exception {
messagePersister.shutdown();
logger.info("Message persister shut down...");
}
private static class Key {
private final byte[] userMessageQueue;
private final byte[] userMessageQueueMetadata;
private final byte[] userMessageQueuePersistInProgress;
private final String address;
private final long deviceId;
Key(String address, long deviceId) {
this.address = address;
this.deviceId = deviceId;
this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes();
this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes();
this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes();
}
String getAddress() {
return address;
}
long getDeviceId() {
return deviceId;
}
byte[] getUserMessageQueue() {
return userMessageQueue;
}
byte[] getUserMessageQueueMetadata() {
return userMessageQueueMetadata;
}
byte[] getUserMessageQueuePersistInProgress() {
return userMessageQueuePersistInProgress;
}
static byte[] getUserMessageQueueIndex() {
return "user_queue_index".getBytes();
}
static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException {
try {
String[] parts = new String(userMessageQueue).split("::");
if (parts.length != 3) {
throw new IOException("Malformed key: " + new String(userMessageQueue));
}
return new Key(parts[1], Long.parseLong(parts[2]));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
}
private static class InsertOperation {
private final LuaScript insert;
@ -343,21 +246,12 @@ public class MessagesCache implements Managed, UserMessagesCache {
private static class GetOperation {
private final LuaScript getQueues;
private final LuaScript getItems;
GetOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.getQueues = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua");
this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua");
}
List<byte[]> getQueues(byte[] queue, long maxTimeMillis, int limit) {
List<byte[]> keys = Collections.singletonList(queue);
List<byte[]> args = Arrays.asList(String.valueOf(maxTimeMillis).getBytes(), String.valueOf(limit).getBytes());
return (List<byte[]>)getQueues.execute(keys, args);
}
List<Pair<byte[], Double>> getItems(byte[] queue, byte[] lock, int limit) {
List<byte[]> keys = Arrays.asList(queue, lock);
List<byte[]> args = Collections.singletonList(String.valueOf(limit).getBytes());
@ -373,171 +267,4 @@ public class MessagesCache implements Managed, UserMessagesCache {
}
}
private class MessagePersister extends Thread {
private final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" ));
private final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue"));
private final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" ));
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" ));
private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount"));
private static final int CHUNK_SIZE = 100;
private final AtomicBoolean running = new AtomicBoolean(true);
private final ReplicatedJedisPool jedisPool;
private final Messages database;
private final long delayTime;
private final TimeUnit delayTimeUnit;
private final PubSubManager pubSubManager;
private final PushSender pushSender;
private final AccountsManager accountsManager;
private final GetOperation getOperation;
private final RemoveOperation removeOperation;
private boolean finished = false;
MessagePersister(ReplicatedJedisPool jedisPool,
Messages database,
PubSubManager pubSubManager,
PushSender pushSender,
AccountsManager accountsManager,
long delayTime,
TimeUnit delayTimeUnit)
throws IOException
{
super(MessagePersister.class.getSimpleName());
this.jedisPool = jedisPool;
this.database = database;
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.delayTime = delayTime;
this.delayTimeUnit = delayTimeUnit;
this.getOperation = new GetOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
}
@Override
public void run() {
while (running.get()) {
try {
List<byte[]> queuesToPersist = getQueuesToPersist(getOperation);
queueCountHistogram.update(queuesToPersist.size());
for (byte[] queue : queuesToPersist) {
Key key = Key.fromUserMessageQueue(queue);
persistQueue(jedisPool, key);
notifyClients(accountsManager, pubSubManager, pushSender, key);
}
if (queuesToPersist.isEmpty()) {
Thread.sleep(10000);
}
} catch (Throwable t) {
logger.error("Exception while persisting: ", t);
}
}
synchronized (this) {
finished = true;
notifyAll();
}
}
synchronized void shutdown() {
running.set(false);
while (!finished) Util.wait(this);
}
private void persistQueue(ReplicatedJedisPool jedisPool, Key key) throws IOException {
Timer.Context timer = persistQueueTimer.time();
int messagesPersistedCount = 0;
UUID destinationUuid = accountsManager.get(key.getAddress()).map(Account::getUuid).orElse(null);
try (Jedis jedis = jedisPool.getWriteResource()) {
while (true) {
jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes());
Set<Tuple> messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE);
for (Tuple message : messages) {
persistMessage(key, destinationUuid, (long)message.getScore(), message.getBinaryElement());
messagesPersistedCount++;
}
if (messages.size() < CHUNK_SIZE) {
jedis.del(key.getUserMessageQueuePersistInProgress());
return;
}
}
} finally {
timer.stop();
queueSizeHistogram.update(messagesPersistedCount);
}
}
private void persistMessage(Key key, UUID destinationUuid, long score, byte[] message) {
try {
Envelope envelope = Envelope.parseFrom(message);
UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null;
envelope = envelope.toBuilder().clearServerGuid().build();
database.store(guid, envelope, key.getAddress(), key.getDeviceId());
} catch (InvalidProtocolBufferException e) {
logger.error("Error parsing envelope", e);
}
remove(key.getAddress(), destinationUuid, key.getDeviceId(), score);
}
private List<byte[]> getQueuesToPersist(GetOperation getOperation) {
Timer.Context timer = getQueuesTimer.time();
try {
long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime);
return getOperation.getQueues(Key.getUserMessageQueueIndex(), maxTime, 100);
} finally {
timer.stop();
}
}
private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key) {
Timer.Context timer = notifyTimer.time();
try {
boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build());
if (!notified) {
Optional<Account> account = accountsManager.get(key.getAddress());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(key.getDeviceId());
if (device.isPresent()) {
try {
pushSender.sendQueuedNotification(account.get(), device.get());
} catch (NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!");
}
}
}
}
} finally {
timer.stop();
}
}
}
}

View File

@ -26,7 +26,7 @@ public class MessagesCacheTest extends AbstractMessagesCacheTest {
final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration());
final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool();
messagesCache = new MessagesCache(jedisPool, mock(Messages.class), mock(AccountsManager.class), 60);
messagesCache = new MessagesCache(jedisPool);
}
@After