diff --git a/pom.xml b/pom.xml index 60786eb6a..3754ea9da 100644 --- a/pom.xml +++ b/pom.xml @@ -78,7 +78,7 @@ redis.clients jedis - 2.7.3 + 2.9.0 jar compile diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 8029374f6..033232713 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -21,6 +21,7 @@ import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; import org.whispersystems.textsecuregcm.configuration.FederationConfiguration; import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; +import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.ProfilesConfiguration; import org.whispersystems.textsecuregcm.configuration.PushConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; @@ -75,6 +76,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private RedisConfiguration directory; + @NotNull + @Valid + @JsonProperty + private MessageCacheConfiguration messageCache; + @Valid @NotNull @JsonProperty @@ -160,6 +166,10 @@ public class WhisperServerConfiguration extends Configuration { return directory; } + public MessageCacheConfiguration getMessageCacheConfiguration() { + return messageCache; + } + public DataSourceFactory getMessageStoreConfiguration() { return messageStore; } diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d89ed7112..4dfb071ae 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2013 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -24,7 +24,6 @@ import com.google.common.base.Optional; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.eclipse.jetty.servlets.CrossOriginFilter; import org.skife.jdbi.v2.DBI; -import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.dropwizard.simpleauth.AuthDynamicFeature; import org.whispersystems.dropwizard.simpleauth.AuthValueFactoryProvider; @@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Messages; +import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PendingAccounts; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; @@ -161,15 +161,17 @@ public class WhisperServerService extends Applicationof(deadLetterHandler)); + DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.of(deadLetterHandler)); PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager); APNSender apnSender = new APNSender(accountsManager, config.getApnConfiguration()); GCMSender gcmSender = new GCMSender(accountsManager, config.getGcmConfiguration().getApiKey()); @@ -186,10 +188,13 @@ public class WhisperServerService extends Application() { - @Override - public Integer getValue() { - return executor.getSize(); - } - }); + (Gauge) executor::getSize); } public void sendMessage(final Account account, final Device device, final Envelope message, final boolean silent) @@ -77,22 +72,17 @@ public class PushSender implements Managed { } if (queueSize > 0) { - executor.execute(new Runnable() { - @Override - public void run() { - sendSynchronousMessage(account, device, message, silent); - } - }); + executor.execute(() -> sendSynchronousMessage(account, device, message, silent)); } else { sendSynchronousMessage(account, device, message, silent); } } - public void sendQueuedNotification(Account account, Device device, int messageQueueDepth, boolean fallback) + public void sendQueuedNotification(Account account, Device device, boolean fallback) throws NotPushRegisteredException, TransientPushFailureException { if (device.getGcmId() != null) sendGcmNotification(account, device); - else if (device.getApnId() != null) sendApnNotification(account, device, messageQueueDepth, fallback); + else if (device.getApnId() != null) sendApnNotification(account, device, fallback); else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!"); } @@ -127,16 +117,15 @@ public class PushSender implements Managed { if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) { boolean fallback = !silent && !outgoingMessage.getSource().equals(account.getNumber()); - sendApnNotification(account, device, deliveryStatus.getMessageQueueDepth(), fallback); + sendApnNotification(account, device, fallback); } } - private void sendApnNotification(Account account, Device device, int messageQueueDepth, boolean fallback) { + private void sendApnNotification(Account account, Device device, boolean fallback) { ApnMessage apnMessage; if (!Util.isEmpty(device.getVoipApnId())) { - apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), - String.format(APN_PAYLOAD, messageQueueDepth), true, + apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION)); if (fallback) { @@ -144,8 +133,7 @@ public class PushSender implements Managed { new ApnFallbackTask(device.getApnId(), device.getVoipApnId(), apnMessage)); } } else { - apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), - String.format(APN_PAYLOAD, messageQueueDepth), + apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, false, ApnMessage.MAX_EXPIRATION); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index 430caaa37..2021c9397 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2014 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -36,12 +36,13 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag public class WebsocketSender { - public static enum Type { + public enum Type { APN, GCM, WEB } + @SuppressWarnings("unused") private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); @@ -79,28 +80,26 @@ public class WebsocketSender { else if (channel == Type.GCM) gcmOnlineMeter.mark(); else websocketOnlineMeter.mark(); - return new DeliveryStatus(true, 0); + return new DeliveryStatus(true); } else { if (channel == Type.APN) apnOfflineMeter.mark(); else if (channel == Type.GCM) gcmOfflineMeter.mark(); else websocketOfflineMeter.mark(); - int queueDepth = queueMessage(account, device, message); - return new DeliveryStatus(false, queueDepth); + queueMessage(account, device, message); + return new DeliveryStatus(false); } } - public int queueMessage(Account account, Device device, Envelope message) { + public void queueMessage(Account account, Device device, Envelope message) { websocketRequeueMeter.mark(); - WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); - int queueDepth = messagesManager.insert(account.getNumber(), device.getId(), message); + WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); + messagesManager.insert(account.getNumber(), device.getId(), message); pubSubManager.publish(address, PubSubMessage.newBuilder() .setType(PubSubMessage.Type.QUERY_DB) .build()); - - return queueDepth; } public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) { @@ -118,22 +117,17 @@ public class WebsocketSender { } } - public static class DeliveryStatus { + static class DeliveryStatus { private final boolean delivered; - private final int messageQueueDepth; - public DeliveryStatus(boolean delivered, int messageQueueDepth) { + DeliveryStatus(boolean delivered) { this.delivered = delivered; - this.messageQueueDepth = messageQueueDepth; } - public boolean isDelivered() { + boolean isDelivered() { return delivered; } - public int getMessageQueueDepth() { - return messageQueueDepth; - } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java b/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java new file mode 100644 index 000000000..643b57dcb --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java @@ -0,0 +1,64 @@ +package org.whispersystems.textsecuregcm.redis; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.List; + +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; +import redis.clients.jedis.exceptions.JedisDataException; + +public class LuaScript { + + private final JedisPool jedisPool; + private final String script; + private final byte[] sha; + + public static LuaScript fromResource(JedisPool jedisPool, String resource) throws IOException { + InputStream inputStream = LuaScript.class.getClassLoader().getResourceAsStream(resource); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + byte[] buffer = new byte[4096]; + int read; + + while ((read = inputStream.read(buffer)) != -1) { + baos.write(buffer, 0, read); + } + + inputStream.close(); + baos.close(); + + return new LuaScript(jedisPool, new String(baos.toByteArray())); + } + + private LuaScript(JedisPool jedisPool, String script) { + this.jedisPool = jedisPool; + this.script = script; + this.sha = storeScript(jedisPool, script).getBytes(); + } + + public Object execute(List keys, List args) { + try (Jedis jedis = jedisPool.getResource()) { + try { + return jedis.evalsha(sha, keys, args); + } catch (JedisDataException e) { + storeScript(jedisPool, script); + return jedis.evalsha(sha, keys, args); + } + } + } + + private String storeScript(JedisPool jedisPool, String script) { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.scriptLoad(script); + } + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java index deb2723c6..c740f44b4 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java @@ -25,7 +25,7 @@ import java.util.List; public abstract class Messages { - public static final int RESULT_SET_CHUNK_SIZE = 100; + static final int RESULT_SET_CHUNK_SIZE = 100; private static final String ID = "id"; private static final String TYPE = "type"; @@ -38,10 +38,9 @@ public abstract class Messages { private static final String MESSAGE = "message"; private static final String CONTENT = "content"; - @SqlQuery("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + - "VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content) " + - "RETURNING (SELECT COUNT(id) FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + TYPE + " != " + Envelope.Type.RECEIPT_VALUE + ")") - abstract int store(@MessageBinder Envelope message, + @SqlUpdate("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + + "VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content)") + abstract void store(@MessageBinder Envelope message, @Bind("destination") String destination, @Bind("destination_device") long destinationDevice); @@ -100,6 +99,7 @@ public abstract class Messages { } return new OutgoingMessageEntity(resultSet.getLong(ID), + false, type, resultSet.getString(RELAY), resultSet.getLong(TIMESTAMP), diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java new file mode 100644 index 000000000..b99636c49 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -0,0 +1,500 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; +import com.codahale.metrics.Timer; +import com.google.common.base.Optional; +import com.google.protobuf.InvalidProtocolBufferException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.push.TransientPushFailureException; +import org.whispersystems.textsecuregcm.redis.LuaScript; +import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.codahale.metrics.MetricRegistry.name; +import io.dropwizard.lifecycle.Managed; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; +import redis.clients.jedis.Tuple; +import redis.clients.util.SafeEncoder; + +public class MessagesCache implements Managed { + + private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class); + + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" )); + private static final Timer removeByIdTimer = metricRegistry.timer(name(MessagesCache.class, "removeById" )); + private static final Timer removeByNameTimer = metricRegistry.timer(name(MessagesCache.class, "removeByName")); + private static final Timer getTimer = metricRegistry.timer(name(MessagesCache.class, "get" )); + private static final Timer clearAccountTimer = metricRegistry.timer(name(MessagesCache.class, "clearAccount")); + private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" )); + + private final JedisPool jedisPool; + private final Messages database; + private final AccountsManager accountsManager; + private final int delayMinutes; + + private InsertOperation insertOperation; + private RemoveOperation removeOperation; + private GetOperation getOperation; + + private PubSubManager pubSubManager; + private PushSender pushSender; + private MessagePersister messagePersister; + + public MessagesCache(JedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) { + this.jedisPool = jedisPool; + this.database = database; + this.accountsManager = accountsManager; + this.delayMinutes = delayMinutes; + } + + public void insert(String destination, long destinationDevice, Envelope message) { + Timer.Context timer = insertTimer.time(); + + try { + insertOperation.insert(destination, destinationDevice, System.currentTimeMillis(), message); + } finally { + timer.stop(); + } + } + + public void remove(String destination, long destinationDevice, long id) { + Timer.Context timer = removeByIdTimer.time(); + + try { + removeOperation.remove(destination, destinationDevice, id); + } finally { + timer.stop(); + } + } + + public Optional remove(String destination, long destinationDevice, String sender, long timestamp) { + Timer.Context timer = removeByNameTimer.time(); + + try { + byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp); + + if (serialized != null) { + Envelope envelope = Envelope.parseFrom(serialized); + return Optional.of(constructEntityFromEnvelope(0, envelope)); + } + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } finally { + timer.stop(); + } + + return Optional.absent(); + } + + public List get(String destination, long destinationDevice, int limit) { + Timer.Context timer = getTimer.time(); + + try { + List results = new LinkedList<>(); + Key key = new Key(destination, destinationDevice); + List> items = getOperation.getItems(key.getUserMessageQueue(), key.getUserMessageQueuePersistInProgress(), limit); + + for (Pair item : items) { + try { + long id = item.second().longValue(); + Envelope message = Envelope.parseFrom(item.first()); + results.add(constructEntityFromEnvelope(id, message)); + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + + return results; + } finally { + timer.stop(); + } + } + + public void clear(String destination) { + Timer.Context timer = clearAccountTimer.time(); + + try { + for (int i = 1; i < 255; i++) { + clear(destination, i); + } + } finally { + timer.stop(); + } + } + + public void clear(String destination, long deviceId) { + Timer.Context timer = clearDeviceTimer.time(); + + try { + removeOperation.clear(destination, deviceId); + } finally { + timer.stop(); + } + } + + public void setPubSubManager(PubSubManager pubSubManager, PushSender pushSender) { + this.pubSubManager = pubSubManager; + this.pushSender = pushSender; + } + + @Override + public void start() throws Exception { + this.insertOperation = new InsertOperation(jedisPool); + this.removeOperation = new RemoveOperation(jedisPool); + this.getOperation = new GetOperation(jedisPool); + this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES); + + this.messagePersister.start(); + } + + @Override + public void stop() throws Exception { + messagePersister.shutdown(); + logger.info("Message persister shut down..."); + } + + private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) { + return new OutgoingMessageEntity(id, true, + envelope.getType().getNumber(), + envelope.getRelay(), + envelope.getTimestamp(), + envelope.getSource(), + envelope.getSourceDevice(), + envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null, + envelope.hasContent() ? envelope.getContent().toByteArray() : null); + } + + private static class Key { + + private final byte[] userMessageQueue; + private final byte[] userMessageQueueMetadata; + private final byte[] userMessageQueuePersistInProgress; + + private final String address; + private final long deviceId; + + Key(String address, long deviceId) { + this.address = address; + this.deviceId = deviceId; + this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes(); + this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes(); + this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes(); + } + + String getAddress() { + return address; + } + + long getDeviceId() { + return deviceId; + } + + byte[] getUserMessageQueue() { + return userMessageQueue; + } + + byte[] getUserMessageQueueMetadata() { + return userMessageQueueMetadata; + } + + byte[] getUserMessageQueuePersistInProgress() { + return userMessageQueuePersistInProgress; + } + + static byte[] getUserMessageQueueIndex() { + return "user_queue_index".getBytes(); + } + + static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException { + try { + String[] parts = new String(userMessageQueue).split("::"); + + if (parts.length != 3) { + throw new IOException("Malformed key: " + new String(userMessageQueue)); + } + + return new Key(parts[1], Long.parseLong(parts[2])); + } catch (NumberFormatException e) { + throw new IOException(e); + } + } + } + + private static class InsertOperation { + private final LuaScript insert; + + InsertOperation(JedisPool jedisPool) throws IOException { + this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua"); + } + + public void insert(String destination, long destinationDevice, long timestamp, Envelope message) { + Key key = new Key(destination, destinationDevice); + String sender = message.getSource() + "::" + message.getTimestamp(); + + List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); + List args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes()); + + insert.execute(keys, args); + } + } + + private static class RemoveOperation { + + private final LuaScript removeById; + private final LuaScript removeBySender; + private final LuaScript removeQueue; + + RemoveOperation(JedisPool jedisPool) throws IOException { + this.removeById = LuaScript.fromResource(jedisPool, "lua/remove_item_by_id.lua" ); + this.removeBySender = LuaScript.fromResource(jedisPool, "lua/remove_item_by_sender.lua"); + this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" ); + } + + public void remove(String destination, long destinationDevice, long id) { + Key key = new Key(destination, destinationDevice); + + List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); + List args = Collections.singletonList(String.valueOf(id).getBytes()); + + this.removeById.execute(keys, args); + } + + public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) { + Key key = new Key(destination, destinationDevice); + String senderKey = sender + "::" + timestamp; + + List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); + List args = Collections.singletonList(senderKey.getBytes()); + + return (byte[])this.removeBySender.execute(keys, args); + } + + public void clear(String destination, long deviceId) { + Key key = new Key(destination, deviceId); + + List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); + List args = new LinkedList<>(); + + this.removeQueue.execute(keys, args); + } + } + + private static class GetOperation { + + private final LuaScript getQueues; + private final LuaScript getItems; + + GetOperation(JedisPool jedisPool) throws IOException { + this.getQueues = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua"); + this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua"); + } + + List getQueues(byte[] queue, long maxTimeMillis, int limit) { + List keys = Collections.singletonList(queue); + List args = Arrays.asList(String.valueOf(maxTimeMillis).getBytes(), String.valueOf(limit).getBytes()); + + return (List)getQueues.execute(keys, args); + } + + List> getItems(byte[] queue, byte[] lock, int limit) { + List keys = Arrays.asList(queue, lock); + List args = Collections.singletonList(String.valueOf(limit).getBytes()); + + Iterator results = ((List) getItems.execute(keys, args)).iterator(); + List> items = new LinkedList<>(); + + while (results.hasNext()) { + items.add(new Pair<>(results.next(), Double.valueOf(SafeEncoder.encode(results.next())))); + } + + return items; + } + } + + private static class MessagePersister extends Thread { + + private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" )); + private static final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue")); + private static final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" )); + private static final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" )); + private static final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount")); + + private static final int CHUNK_SIZE = 100; + + private final AtomicBoolean running = new AtomicBoolean(true); + + private final JedisPool jedisPool; + private final Messages database; + private final long delayTime; + private final TimeUnit delayTimeUnit; + + private final PubSubManager pubSubManager; + private final PushSender pushSender; + private final AccountsManager accountsManager; + + private final GetOperation getOperation; + private final RemoveOperation removeOperation; + + private boolean finished = false; + + MessagePersister(JedisPool jedisPool, + Messages database, + PubSubManager pubSubManager, + PushSender pushSender, + AccountsManager accountsManager, + long delayTime, + TimeUnit delayTimeUnit) + throws IOException + { + super(MessagePersister.class.getSimpleName()); + this.jedisPool = jedisPool; + this.database = database; + + this.pubSubManager = pubSubManager; + this.pushSender = pushSender; + this.accountsManager = accountsManager; + + this.delayTime = delayTime; + this.delayTimeUnit = delayTimeUnit; + this.getOperation = new GetOperation(jedisPool); + this.removeOperation = new RemoveOperation(jedisPool); + } + + @Override + public void run() { + while (running.get()) { + try { + List queuesToPersist = getQueuesToPersist(getOperation); + queueCountHistogram.update(queuesToPersist.size()); + + for (byte[] queue : queuesToPersist) { + Key key = Key.fromUserMessageQueue(queue); + + persistQueue(jedisPool, key); + notifyClients(accountsManager, pubSubManager, pushSender, key); + } + + if (queuesToPersist.isEmpty()) { + Thread.sleep(10000); + } + } catch (Throwable t) { + logger.error("Exception while persisting: ", t); + } + } + + synchronized (this) { + finished = true; + notifyAll(); + } + } + + synchronized void shutdown() { + running.set(false); + while (!finished) Util.wait(this); + } + + private void persistQueue(JedisPool jedisPool, Key key) throws IOException { + Timer.Context timer = persistQueueTimer.time(); + + int messagesPersistedCount = 0; + + try (Jedis jedis = jedisPool.getResource()) { + while (true) { + jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes()); + + Set messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE); + + for (Tuple message : messages) { + persistMessage(key, (long)message.getScore(), message.getBinaryElement()); + messagesPersistedCount++; + } + + if (messages.size() < CHUNK_SIZE) { + jedis.del(key.getUserMessageQueuePersistInProgress()); + return; + } + } + } finally { + timer.stop(); + queueSizeHistogram.update(messagesPersistedCount); + } + } + + private void persistMessage(Key key, long score, byte[] message) { + try { + Envelope envelope = Envelope.parseFrom(message); + database.store(envelope, key.getAddress(), key.getDeviceId()); + } catch (InvalidProtocolBufferException e) { + logger.error("Error parsing envelope", e); + } + + removeOperation.remove(key.getAddress(), key.getDeviceId(), score); + } + + private List getQueuesToPersist(GetOperation getOperation) { + Timer.Context timer = getQueuesTimer.time(); + try { + long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime); + return getOperation.getQueues(Key.getUserMessageQueueIndex(), maxTime, 100); + } finally { + timer.stop(); + } + } + + private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key) + throws IOException + { + Timer.Context timer = notifyTimer.time(); + + try { + boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()), + PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) + .build()); + + if (!notified) { + Optional account = accountsManager.get(key.getAddress()); + + if (account.isPresent()) { + Optional device = account.get().getDevice(key.getDeviceId()); + + if (device.isPresent()) { + try { + pushSender.sendQueuedNotification(account.get(), device.get(), false); + } catch (NotPushRegisteredException e) { + logger.warn("After message persistence, no longer push registered!"); + } catch (TransientPushFailureException e) { + logger.warn("Transient push failure!", e); + } + } + } + } + } finally { + timer.stop(); + } + } + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index fff0213bf..fbafcd232 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -1,44 +1,118 @@ package org.whispersystems.textsecuregcm.storage; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; import com.google.common.base.Optional; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Conversions; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.List; +import static com.codahale.metrics.MetricRegistry.name; + public class MessagesManager { - private final Messages messages; + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Meter cacheHitByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitById" )); + private static final Meter cacheMissByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissById" )); + private static final Meter cacheHitByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByName" )); + private static final Meter cacheMissByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissByName")); - public MessagesManager(Messages messages) { - this.messages = messages; + private final Messages messages; + private final MessagesCache messagesCache; + private final Distribution distribution; + + public MessagesManager(Messages messages, MessagesCache messagesCache, float cacheRate) { + this.messages = messages; + this.messagesCache = messagesCache; + this.distribution = new Distribution(cacheRate); } - public int insert(String destination, long destinationDevice, Envelope message) { - return this.messages.store(message, destination, destinationDevice) + 1; + public void insert(String destination, long destinationDevice, Envelope message) { + if (distribution.isQualified(destination, destinationDevice)) { + messagesCache.insert(destination, destinationDevice, message); + } else { + messages.store(message, destination, destinationDevice); + } } public OutgoingMessageEntityList getMessagesForDevice(String destination, long destinationDevice) { List messages = this.messages.load(destination, destinationDevice); + + if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { + messages.addAll(this.messagesCache.get(destination, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); + } + return new OutgoingMessageEntityList(messages, messages.size() >= Messages.RESULT_SET_CHUNK_SIZE); } public void clear(String destination) { + this.messagesCache.clear(destination); this.messages.clear(destination); } public void clear(String destination, long deviceId) { + this.messagesCache.clear(destination, deviceId); this.messages.clear(destination, deviceId); } public Optional delete(String destination, long destinationDevice, String source, long timestamp) { - return Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp)); + Optional removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp); + + if (!removed.isPresent()) { + removed = Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp)); + cacheMissByNameMeter.mark(); + } else { + cacheHitByNameMeter.mark(); + } + + return removed; } - public void delete(String destination, long id) { - this.messages.remove(destination, id); + public void delete(String destination, long deviceId, long id, boolean cached) { + if (cached) { + this.messagesCache.remove(destination, deviceId, id); + cacheHitByIdMeter.mark(); + } else { + this.messages.remove(destination, id); + cacheMissByIdMeter.mark(); + } } + + public static class Distribution { + + private final float percentage; + + public Distribution(float percentage) { + this.percentage = percentage; + } + + public boolean isQualified(String address, long device) { + if (percentage <= 0) return false; + if (percentage >= 100) return true; + + try { + MessageDigest digest = MessageDigest.getInstance("SHA1"); + digest.update(address.getBytes()); + digest.update(Conversions.longToByteArray(device)); + + byte[] result = digest.digest(); + int hashCode = Conversions.byteArrayToShort(result); + + return hashCode <= 65535 * percentage; + } catch (NoSuchAlgorithmException e) { + throw new AssertionError(e); + } + } + + } + } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 7fb703c2b..723fd11b3 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -82,7 +82,7 @@ public class WebSocketConnection implements DispatchChannel { processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: - sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false); + sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false); break; case PubSubMessage.Type.CONNECTED_VALUE: if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) { @@ -106,9 +106,9 @@ public class WebSocketConnection implements DispatchChannel { processStoredMessages(); } - private void sendMessage(final Envelope message, - final Optional storedMessageId, - final boolean requery) + private void sendMessage(final Envelope message, + final Optional storedMessageInfo, + final boolean requery) { try { EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, device.getSignalingKey()); @@ -125,17 +125,17 @@ public class WebSocketConnection implements DispatchChannel { } if (isSuccessResponse(response)) { - if (storedMessageId.isPresent()) messagesManager.delete(account.getNumber(), storedMessageId.get()); - if (!isReceipt) sendDeliveryReceiptFor(message); - if (requery) processStoredMessages(); - } else if (!isSuccessResponse(response) && !storedMessageId.isPresent()) { + if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); + if (!isReceipt) sendDeliveryReceiptFor(message); + if (requery) processStoredMessages(); + } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) { requeueMessage(message); } } @Override public void onFailure(@Nonnull Throwable throwable) { - if (!storedMessageId.isPresent()) requeueMessage(message); + if (!storedMessageInfo.isPresent()) requeueMessage(message); } private boolean isSuccessResponse(WebSocketResponseMessage response) { @@ -148,11 +148,12 @@ public class WebSocketConnection implements DispatchChannel { } private void requeueMessage(Envelope message) { - int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message); - boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT; + pushSender.getWebSocketSender().queueMessage(account, device, message); + + boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT; try { - pushSender.sendQueuedNotification(account, device, queueDepth, fallback); + pushSender.sendQueuedNotification(account, device, fallback); } catch (NotPushRegisteredException | TransientPushFailureException e) { logger.warn("requeueMessage", e); } @@ -162,7 +163,7 @@ public class WebSocketConnection implements DispatchChannel { try { receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp(), message.hasRelay() ? Optional.of(message.getRelay()) : - Optional.absent()); + Optional.absent()); } catch (NoSuchUserException | NotPushRegisteredException e) { logger.info("No longer registered " + e.getMessage()); } catch(IOException | TransientPushFailureException e) { @@ -196,11 +197,21 @@ public class WebSocketConnection implements DispatchChannel { builder.setRelay(message.getRelay()); } - sendMessage(builder.build(), Optional.of(message.getId()), !iterator.hasNext() && messages.hasMore()); + sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore()); } if (!messages.hasMore()) { client.sendRequest("PUT", "/api/v1/queue/empty", null, Optional.absent()); } } + + private static class StoredMessageInfo { + private final long id; + private final boolean cached; + + private StoredMessageInfo(long id, boolean cached) { + this.id = id; + this.cached = cached; + } + } } diff --git a/src/main/resources/lua/get_items.lua b/src/main/resources/lua/get_items.lua new file mode 100644 index 000000000..7a966d039 --- /dev/null +++ b/src/main/resources/lua/get_items.lua @@ -0,0 +1,10 @@ +-- keys: queue_key, queue_locked_key +-- argv: limit + +local locked = redis.call("GET", KEYS[2]) + +if locked then + return {} +end + +return redis.call("ZRANGE", KEYS[1], 0, ARGV[1], "WITHSCORES") \ No newline at end of file diff --git a/src/main/resources/lua/get_queues_to_persist.lua b/src/main/resources/lua/get_queues_to_persist.lua new file mode 100644 index 000000000..b4afdaf43 --- /dev/null +++ b/src/main/resources/lua/get_queues_to_persist.lua @@ -0,0 +1,10 @@ +-- keys: queue_total_index +-- argv: max_time, limit + +local results = redis.call("ZRANGEBYSCORE", KEYS[1], 0, ARGV[1], "LIMIT", 0, ARGV[2]) + +if results and next(results) then + redis.call("ZREM", KEYS[1], unpack(results)) +end + +return results diff --git a/src/main/resources/lua/insert_item.lua b/src/main/resources/lua/insert_item.lua new file mode 100644 index 000000000..4f36a69df --- /dev/null +++ b/src/main/resources/lua/insert_item.lua @@ -0,0 +1,13 @@ +-- keys: queue_key, queue_metadata_key, queue_total_index +-- argv: message, current_time, sender_key + +local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1) +redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1]) +redis.call("HSET", KEYS[2], ARGV[3], messageId) +redis.call("HSET", KEYS[2], messageId, ARGV[3]) + +redis.call("EXPIRE", KEYS[1], 7776000) +redis.call("EXPIRE", KEYS[2], 7776000) + +redis.call("ZADD", KEYS[3], "NX", ARGV[2], KEYS[1]) +return messageId \ No newline at end of file diff --git a/src/main/resources/lua/remove_item_by_id.lua b/src/main/resources/lua/remove_item_by_id.lua new file mode 100644 index 000000000..b5e8496ba --- /dev/null +++ b/src/main/resources/lua/remove_item_by_id.lua @@ -0,0 +1,16 @@ +-- keys: queue_key, queue_metadata_key, queue_index +-- argv: index_to_remove + +local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1]) +local senderIndex = redis.call("HGET", KEYS[2], ARGV[1]) + +if senderIndex then + redis.call("HDEL", KEYS[2], senderIndex) + redis.call("HDEL", KEYS[2], ARGV[1]) +end + +if (redis.call("ZCARD", KEYS[1]) == 0) then + redis.call("ZREM", KEYS[3], KEYS[1]) +end + +return removedCount > 0 diff --git a/src/main/resources/lua/remove_item_by_sender.lua b/src/main/resources/lua/remove_item_by_sender.lua new file mode 100644 index 000000000..9aaf7639f --- /dev/null +++ b/src/main/resources/lua/remove_item_by_sender.lua @@ -0,0 +1,22 @@ +-- keys: queue_key, queue_metadata_key, queue_index +-- argv: sender_to_remove + +local messageId = redis.call("HGET", KEYS[2], ARGV[1]) + +if messageId then + local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], messageId, messageId, "LIMIT", 0, 1) + + redis.call("ZREMRANGEBYSCORE", KEYS[1], messageId, messageId) + redis.call("HDEL", KEYS[2], ARGV[1]) + redis.call("HDEL", KEYS[2], messageId) + + if (redis.call("ZCARD", KEYS[1]) == 0) then + redis.call("ZREM", KEYS[3], KEYS[1]) + end + + if envelope and next(envelope) then + return envelope[1] + end +end + +return nil diff --git a/src/main/resources/lua/remove_queue.lua b/src/main/resources/lua/remove_queue.lua new file mode 100644 index 000000000..315a0fa7e --- /dev/null +++ b/src/main/resources/lua/remove_queue.lua @@ -0,0 +1,5 @@ +-- keys: queue_key, queue_metadata_key, queue_index + +redis.call("DEL", KEYS[1]) +redis.call("DEL", KEYS[2]) +redis.call("ZREM", KEYS[3], KEYS[1]) diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index 809d47272..efa222821 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -186,8 +186,8 @@ public class MessageControllerTest { final long timestampTwo = 313388; List messages = new LinkedList() {{ - add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null)); - add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null)); + add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null)); + add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null)); }}; OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); @@ -217,8 +217,8 @@ public class MessageControllerTest { final long timestampTwo = 313388; List messages = new LinkedList() {{ - add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null)); - add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null)); + add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null)); + add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null)); }}; OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); @@ -239,13 +239,13 @@ public class MessageControllerTest { public synchronized void testDeleteMessages() throws Exception { long timestamp = System.currentTimeMillis(); when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337)) - .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, + .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, "+14152222222", 1, "hi".getBytes(), null))); when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338)) - .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, + .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, Envelope.Type.RECEIPT_VALUE, null, System.currentTimeMillis(), "+14152222222", 1, null, null))); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 949e5b1e5..2e10578c3 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -100,9 +100,9 @@ public class WebSocketConnectionTest { MessagesManager storedMessages = mock(MessagesManager.class); List outgoingMessages = new LinkedList () {{ - add(createMessage(1L, "sender1", 1111, false, "first")); - add(createMessage(2L, "sender1", 2222, false, "second")); - add(createMessage(3L, "sender2", 3333, false, "third")); + add(createMessage(1L, false, "sender1", 1111, false, "first")); + add(createMessage(2L, false, "sender1", 2222, false, "second")); + add(createMessage(3L, false, "sender2", 3333, false, "third")); }}; OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); @@ -157,7 +157,7 @@ public class WebSocketConnectionTest { futures.get(0).setException(new IOException()); futures.get(2).setException(new IOException()); - verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L)); + verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L), eq(2L), eq(false)); verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L), eq(Optional.absent())); connection.onDispatchUnsubscribed(websocketAddress.serialize()); @@ -170,7 +170,6 @@ public class WebSocketConnectionTest { WebsocketSender websocketSender = mock(WebsocketSender.class); when(pushSender.getWebSocketSender()).thenReturn(websocketSender); - when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10); Envelope firstMessage = Envelope.newBuilder() .setLegacyMessage(ByteString.copyFrom("first".getBytes())) @@ -251,7 +250,7 @@ public class WebSocketConnectionTest { verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.absent())); verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class)); - verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(10), eq(true)); + verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(true)); connection.onDispatchUnsubscribed(websocketAddress.serialize()); verify(client).close(anyInt(), anyString()); @@ -266,7 +265,6 @@ public class WebSocketConnectionTest { reset(pushSender); when(pushSender.getWebSocketSender()).thenReturn(websocketSender); - when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10); final Envelope firstMessage = Envelope.newBuilder() .setLegacyMessage(ByteString.copyFrom("first".getBytes())) @@ -285,11 +283,11 @@ public class WebSocketConnectionTest { .build(); List pendingMessages = new LinkedList() {{ - add(new OutgoingMessageEntity(1, firstMessage.getType().getNumber(), firstMessage.getRelay(), + add(new OutgoingMessageEntity(1, true, firstMessage.getType().getNumber(), firstMessage.getRelay(), firstMessage.getTimestamp(), firstMessage.getSource(), firstMessage.getSourceDevice(), firstMessage.getLegacyMessage().toByteArray(), firstMessage.getContent().toByteArray())); - add(new OutgoingMessageEntity(2, secondMessage.getType().getNumber(), secondMessage.getRelay(), + add(new OutgoingMessageEntity(2, false, secondMessage.getType().getNumber(), secondMessage.getRelay(), secondMessage.getTimestamp(), secondMessage.getSource(), secondMessage.getSourceDevice(), secondMessage.getLegacyMessage().toByteArray(), secondMessage.getContent().toByteArray())); @@ -355,8 +353,8 @@ public class WebSocketConnectionTest { } - private OutgoingMessageEntity createMessage(long id, String sender, long timestamp, boolean receipt, String content) { - return new OutgoingMessageEntity(id, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, long timestamp, boolean receipt, String content) { + return new OutgoingMessageEntity(id, cached, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, 1, content.getBytes(), null); }