From 5fad8f74b1402fe76d2d3df494c2f1149220e454 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 6 Aug 2020 17:38:06 -0400 Subject: [PATCH] Factor MessagePersister into its own class. --- .../textsecuregcm/WhisperServerService.java | 8 +- .../textsecuregcm/storage/Key.java | 59 ++++ .../storage/MessagePersister.java | 214 ++++++++++++++ .../textsecuregcm/storage/MessagesCache.java | 277 +----------------- .../storage/MessagesCacheTest.java | 2 +- 5 files changed, 280 insertions(+), 280 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/Key.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 6121319a6..4e130c1f2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -128,6 +128,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryReconciler; import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.Messages; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -344,7 +345,7 @@ public class WhisperServerService extends Application queuesToPersist = getQueuesToPersist(); + queueCountHistogram.update(queuesToPersist.size()); + + for (byte[] queue : queuesToPersist) { + Key key = Key.fromUserMessageQueue(queue); + + persistQueue(jedisPool, key); + notifyClients(accountsManager, pubSubManager, pushSender, key); + } + + if (queuesToPersist.isEmpty()) { + //noinspection BusyWait + Thread.sleep(10_000); + } + } catch (Throwable t) { + logger.error("Exception while persisting: ", t); + } + } + + synchronized (this) { + finished = true; + notifyAll(); + } + } + + @Override + public synchronized void stop() { + running.set(false); + while (!finished) Util.wait(this); + + logger.info("Message persister shut down..."); + } + + private void persistQueue(ReplicatedJedisPool jedisPool, Key key) { + Timer.Context timer = persistQueueTimer.time(); + + int messagesPersistedCount = 0; + + UUID destinationUuid = accountsManager.get(key.getAddress()).map(Account::getUuid).orElse(null); + + try (Jedis jedis = jedisPool.getWriteResource()) { + while (true) { + jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes()); + + Set messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE); + + for (Tuple message : messages) { + persistMessage(key, destinationUuid, (long)message.getScore(), message.getBinaryElement()); + messagesPersistedCount++; + } + + if (messages.size() < CHUNK_SIZE) { + jedis.del(key.getUserMessageQueuePersistInProgress()); + return; + } + } + } finally { + timer.stop(); + queueSizeHistogram.update(messagesPersistedCount); + } + } + + private void persistMessage(Key key, UUID destinationUuid, long score, byte[] message) { + try { + MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(message); + UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null; + + envelope = envelope.toBuilder().clearServerGuid().build(); + + database.store(guid, envelope, key.getAddress(), key.getDeviceId()); + } catch (InvalidProtocolBufferException e) { + logger.error("Error parsing envelope", e); + } + + messagesCache.remove(key.getAddress(), destinationUuid, key.getDeviceId(), score); + } + + private List getQueuesToPersist() { + Timer.Context timer = getQueuesTimer.time(); + try { + long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime); + List keys = Collections.singletonList(Key.getUserMessageQueueIndex()); + List args = Arrays.asList(String.valueOf(maxTime).getBytes(), String.valueOf(100).getBytes()); + + //noinspection unchecked + return (List)getQueuesScript.execute(keys, args); + } finally { + timer.stop(); + } + } + + private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key) { + Timer.Context timer = notifyTimer.time(); + + try { + boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()), + PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) + .build()); + + if (!notified) { + Optional account = accountsManager.get(key.getAddress()); + + if (account.isPresent()) { + Optional device = account.get().getDevice(key.getDeviceId()); + + if (device.isPresent()) { + try { + pushSender.sendQueuedNotification(account.get(), device.get()); + } catch (NotPushRegisteredException e) { + logger.warn("After message persistence, no longer push registered!"); + } + } + } + } + } finally { + timer.stop(); + } + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 1dfe8acad..e22667712 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -1,26 +1,18 @@ package org.whispersystems.textsecuregcm.storage; -import com.codahale.metrics.Histogram; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; import com.google.protobuf.InvalidProtocolBufferException; -import io.dropwizard.lifecycle.Managed; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.experiment.Experiment; -import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; -import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.redis.LuaScript; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Pair; -import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import redis.clients.jedis.Jedis; -import redis.clients.jedis.Tuple; import redis.clients.util.SafeEncoder; import java.io.IOException; @@ -30,17 +22,11 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Optional; -import java.util.Set; import java.util.UUID; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import static com.codahale.metrics.MetricRegistry.name; -public class MessagesCache implements Managed, UserMessagesCache { +public class MessagesCache implements UserMessagesCache { private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class); @@ -54,23 +40,13 @@ public class MessagesCache implements Managed, UserMessagesCache { private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" )); private final ReplicatedJedisPool jedisPool; - private final Messages database; - private final AccountsManager accountsManager; - private final int delayMinutes; private final InsertOperation insertOperation; private final RemoveOperation removeOperation; private final GetOperation getOperation; - private PubSubManager pubSubManager; - private PushSender pushSender; - private MessagePersister messagePersister; - - public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) throws IOException { + public MessagesCache(ReplicatedJedisPool jedisPool) throws IOException { this.jedisPool = jedisPool; - this.database = database; - this.accountsManager = accountsManager; - this.delayMinutes = delayMinutes; this.insertOperation = new InsertOperation(jedisPool); this.removeOperation = new RemoveOperation(jedisPool); @@ -198,79 +174,6 @@ public class MessagesCache implements Managed, UserMessagesCache { } } - public void setPubSubManager(PubSubManager pubSubManager, PushSender pushSender) { - this.pubSubManager = pubSubManager; - this.pushSender = pushSender; - } - - @Override - public void start() throws Exception { - this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES); - this.messagePersister.start(); - } - - @Override - public void stop() throws Exception { - messagePersister.shutdown(); - logger.info("Message persister shut down..."); - } - - private static class Key { - - private final byte[] userMessageQueue; - private final byte[] userMessageQueueMetadata; - private final byte[] userMessageQueuePersistInProgress; - - private final String address; - private final long deviceId; - - Key(String address, long deviceId) { - this.address = address; - this.deviceId = deviceId; - this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes(); - this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes(); - this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes(); - } - - String getAddress() { - return address; - } - - long getDeviceId() { - return deviceId; - } - - byte[] getUserMessageQueue() { - return userMessageQueue; - } - - byte[] getUserMessageQueueMetadata() { - return userMessageQueueMetadata; - } - - byte[] getUserMessageQueuePersistInProgress() { - return userMessageQueuePersistInProgress; - } - - static byte[] getUserMessageQueueIndex() { - return "user_queue_index".getBytes(); - } - - static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException { - try { - String[] parts = new String(userMessageQueue).split("::"); - - if (parts.length != 3) { - throw new IOException("Malformed key: " + new String(userMessageQueue)); - } - - return new Key(parts[1], Long.parseLong(parts[2])); - } catch (NumberFormatException e) { - throw new IOException(e); - } - } - } - private static class InsertOperation { private final LuaScript insert; @@ -343,21 +246,12 @@ public class MessagesCache implements Managed, UserMessagesCache { private static class GetOperation { - private final LuaScript getQueues; private final LuaScript getItems; GetOperation(ReplicatedJedisPool jedisPool) throws IOException { - this.getQueues = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua"); this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua"); } - List getQueues(byte[] queue, long maxTimeMillis, int limit) { - List keys = Collections.singletonList(queue); - List args = Arrays.asList(String.valueOf(maxTimeMillis).getBytes(), String.valueOf(limit).getBytes()); - - return (List)getQueues.execute(keys, args); - } - List> getItems(byte[] queue, byte[] lock, int limit) { List keys = Arrays.asList(queue, lock); List args = Collections.singletonList(String.valueOf(limit).getBytes()); @@ -373,171 +267,4 @@ public class MessagesCache implements Managed, UserMessagesCache { } } - private class MessagePersister extends Thread { - - private final Logger logger = LoggerFactory.getLogger(MessagePersister.class); - private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" )); - private final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue")); - private final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" )); - private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" )); - private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount")); - - private static final int CHUNK_SIZE = 100; - - private final AtomicBoolean running = new AtomicBoolean(true); - - private final ReplicatedJedisPool jedisPool; - private final Messages database; - private final long delayTime; - private final TimeUnit delayTimeUnit; - - private final PubSubManager pubSubManager; - private final PushSender pushSender; - private final AccountsManager accountsManager; - - private final GetOperation getOperation; - private final RemoveOperation removeOperation; - - private boolean finished = false; - - MessagePersister(ReplicatedJedisPool jedisPool, - Messages database, - PubSubManager pubSubManager, - PushSender pushSender, - AccountsManager accountsManager, - long delayTime, - TimeUnit delayTimeUnit) - throws IOException - { - super(MessagePersister.class.getSimpleName()); - this.jedisPool = jedisPool; - this.database = database; - - this.pubSubManager = pubSubManager; - this.pushSender = pushSender; - this.accountsManager = accountsManager; - - this.delayTime = delayTime; - this.delayTimeUnit = delayTimeUnit; - this.getOperation = new GetOperation(jedisPool); - this.removeOperation = new RemoveOperation(jedisPool); - } - - @Override - public void run() { - while (running.get()) { - try { - List queuesToPersist = getQueuesToPersist(getOperation); - queueCountHistogram.update(queuesToPersist.size()); - - for (byte[] queue : queuesToPersist) { - Key key = Key.fromUserMessageQueue(queue); - - persistQueue(jedisPool, key); - notifyClients(accountsManager, pubSubManager, pushSender, key); - } - - if (queuesToPersist.isEmpty()) { - Thread.sleep(10000); - } - } catch (Throwable t) { - logger.error("Exception while persisting: ", t); - } - } - - synchronized (this) { - finished = true; - notifyAll(); - } - } - - synchronized void shutdown() { - running.set(false); - while (!finished) Util.wait(this); - } - - private void persistQueue(ReplicatedJedisPool jedisPool, Key key) throws IOException { - Timer.Context timer = persistQueueTimer.time(); - - int messagesPersistedCount = 0; - - UUID destinationUuid = accountsManager.get(key.getAddress()).map(Account::getUuid).orElse(null); - - try (Jedis jedis = jedisPool.getWriteResource()) { - while (true) { - jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes()); - - Set messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE); - - for (Tuple message : messages) { - persistMessage(key, destinationUuid, (long)message.getScore(), message.getBinaryElement()); - messagesPersistedCount++; - } - - if (messages.size() < CHUNK_SIZE) { - jedis.del(key.getUserMessageQueuePersistInProgress()); - return; - } - } - } finally { - timer.stop(); - queueSizeHistogram.update(messagesPersistedCount); - } - } - - private void persistMessage(Key key, UUID destinationUuid, long score, byte[] message) { - try { - Envelope envelope = Envelope.parseFrom(message); - UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null; - - envelope = envelope.toBuilder().clearServerGuid().build(); - - database.store(guid, envelope, key.getAddress(), key.getDeviceId()); - } catch (InvalidProtocolBufferException e) { - logger.error("Error parsing envelope", e); - } - - remove(key.getAddress(), destinationUuid, key.getDeviceId(), score); - } - - private List getQueuesToPersist(GetOperation getOperation) { - Timer.Context timer = getQueuesTimer.time(); - try { - long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime); - return getOperation.getQueues(Key.getUserMessageQueueIndex(), maxTime, 100); - } finally { - timer.stop(); - } - } - - private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key) { - Timer.Context timer = notifyTimer.time(); - - try { - boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()), - PubSubProtos.PubSubMessage.newBuilder() - .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) - .build()); - - if (!notified) { - Optional account = accountsManager.get(key.getAddress()); - - if (account.isPresent()) { - Optional device = account.get().getDevice(key.getDeviceId()); - - if (device.isPresent()) { - try { - pushSender.sendQueuedNotification(account.get(), device.get()); - } catch (NotPushRegisteredException e) { - logger.warn("After message persistence, no longer push registered!"); - } - } - } - } - } finally { - timer.stop(); - } - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 23b813019..00ce1d333 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -26,7 +26,7 @@ public class MessagesCacheTest extends AbstractMessagesCacheTest { final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()); final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool(); - messagesCache = new MessagesCache(jedisPool, mock(Messages.class), mock(AccountsManager.class), 60); + messagesCache = new MessagesCache(jedisPool); } @After