diff --git a/pom.xml b/pom.xml
index 60786eb6a..3754ea9da 100644
--- a/pom.xml
+++ b/pom.xml
@@ -78,7 +78,7 @@
redis.clients
jedis
- 2.7.3
+ 2.9.0
jar
compile
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
index 8029374f6..033232713 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
@@ -21,6 +21,7 @@ import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.FederationConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
+import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.ProfilesConfiguration;
import org.whispersystems.textsecuregcm.configuration.PushConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
@@ -75,6 +76,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private RedisConfiguration directory;
+ @NotNull
+ @Valid
+ @JsonProperty
+ private MessageCacheConfiguration messageCache;
+
@Valid
@NotNull
@JsonProperty
@@ -160,6 +166,10 @@ public class WhisperServerConfiguration extends Configuration {
return directory;
}
+ public MessageCacheConfiguration getMessageCacheConfiguration() {
+ return messageCache;
+ }
+
public DataSourceFactory getMessageStoreConfiguration() {
return messageStore;
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index d89ed7112..4dfb071ae 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -1,4 +1,4 @@
-/**
+/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@@ -24,7 +24,6 @@ import com.google.common.base.Optional;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
-import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.dropwizard.simpleauth.AuthDynamicFeature;
import org.whispersystems.dropwizard.simpleauth.AuthValueFactoryProvider;
@@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.Messages;
+import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
@@ -161,15 +161,17 @@ public class WhisperServerService extends Applicationof(deadLetterHandler));
+ DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.of(deadLetterHandler));
PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
APNSender apnSender = new APNSender(accountsManager, config.getApnConfiguration());
GCMSender gcmSender = new GCMSender(accountsManager, config.getGcmConfiguration().getApiKey());
@@ -186,10 +188,13 @@ public class WhisperServerService extends Application() {
- @Override
- public Integer getValue() {
- return executor.getSize();
- }
- });
+ (Gauge) executor::getSize);
}
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean silent)
@@ -77,22 +72,17 @@ public class PushSender implements Managed {
}
if (queueSize > 0) {
- executor.execute(new Runnable() {
- @Override
- public void run() {
- sendSynchronousMessage(account, device, message, silent);
- }
- });
+ executor.execute(() -> sendSynchronousMessage(account, device, message, silent));
} else {
sendSynchronousMessage(account, device, message, silent);
}
}
- public void sendQueuedNotification(Account account, Device device, int messageQueueDepth, boolean fallback)
+ public void sendQueuedNotification(Account account, Device device, boolean fallback)
throws NotPushRegisteredException, TransientPushFailureException
{
if (device.getGcmId() != null) sendGcmNotification(account, device);
- else if (device.getApnId() != null) sendApnNotification(account, device, messageQueueDepth, fallback);
+ else if (device.getApnId() != null) sendApnNotification(account, device, fallback);
else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!");
}
@@ -127,16 +117,15 @@ public class PushSender implements Managed {
if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) {
boolean fallback = !silent && !outgoingMessage.getSource().equals(account.getNumber());
- sendApnNotification(account, device, deliveryStatus.getMessageQueueDepth(), fallback);
+ sendApnNotification(account, device, fallback);
}
}
- private void sendApnNotification(Account account, Device device, int messageQueueDepth, boolean fallback) {
+ private void sendApnNotification(Account account, Device device, boolean fallback) {
ApnMessage apnMessage;
if (!Util.isEmpty(device.getVoipApnId())) {
- apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(),
- String.format(APN_PAYLOAD, messageQueueDepth), true,
+ apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, true,
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION));
if (fallback) {
@@ -144,8 +133,7 @@ public class PushSender implements Managed {
new ApnFallbackTask(device.getApnId(), device.getVoipApnId(), apnMessage));
}
} else {
- apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(),
- String.format(APN_PAYLOAD, messageQueueDepth),
+ apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD,
false, ApnMessage.MAX_EXPIRATION);
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
index 430caaa37..2021c9397 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
@@ -1,4 +1,4 @@
-/**
+/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@@ -36,12 +36,13 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag
public class WebsocketSender {
- public static enum Type {
+ public enum Type {
APN,
GCM,
WEB
}
+ @SuppressWarnings("unused")
private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
@@ -79,28 +80,26 @@ public class WebsocketSender {
else if (channel == Type.GCM) gcmOnlineMeter.mark();
else websocketOnlineMeter.mark();
- return new DeliveryStatus(true, 0);
+ return new DeliveryStatus(true);
} else {
if (channel == Type.APN) apnOfflineMeter.mark();
else if (channel == Type.GCM) gcmOfflineMeter.mark();
else websocketOfflineMeter.mark();
- int queueDepth = queueMessage(account, device, message);
- return new DeliveryStatus(false, queueDepth);
+ queueMessage(account, device, message);
+ return new DeliveryStatus(false);
}
}
- public int queueMessage(Account account, Device device, Envelope message) {
+ public void queueMessage(Account account, Device device, Envelope message) {
websocketRequeueMeter.mark();
- WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
- int queueDepth = messagesManager.insert(account.getNumber(), device.getId(), message);
+ WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
+ messagesManager.insert(account.getNumber(), device.getId(), message);
pubSubManager.publish(address, PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.QUERY_DB)
.build());
-
- return queueDepth;
}
public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) {
@@ -118,22 +117,17 @@ public class WebsocketSender {
}
}
- public static class DeliveryStatus {
+ static class DeliveryStatus {
private final boolean delivered;
- private final int messageQueueDepth;
- public DeliveryStatus(boolean delivered, int messageQueueDepth) {
+ DeliveryStatus(boolean delivered) {
this.delivered = delivered;
- this.messageQueueDepth = messageQueueDepth;
}
- public boolean isDelivered() {
+ boolean isDelivered() {
return delivered;
}
- public int getMessageQueueDepth() {
- return messageQueueDepth;
- }
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java b/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java
new file mode 100644
index 000000000..643b57dcb
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/redis/LuaScript.java
@@ -0,0 +1,64 @@
+package org.whispersystems.textsecuregcm.redis;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.util.List;
+
+import redis.clients.jedis.Jedis;
+import redis.clients.jedis.JedisPool;
+import redis.clients.jedis.exceptions.JedisDataException;
+
+public class LuaScript {
+
+ private final JedisPool jedisPool;
+ private final String script;
+ private final byte[] sha;
+
+ public static LuaScript fromResource(JedisPool jedisPool, String resource) throws IOException {
+ InputStream inputStream = LuaScript.class.getClassLoader().getResourceAsStream(resource);
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+
+ byte[] buffer = new byte[4096];
+ int read;
+
+ while ((read = inputStream.read(buffer)) != -1) {
+ baos.write(buffer, 0, read);
+ }
+
+ inputStream.close();
+ baos.close();
+
+ return new LuaScript(jedisPool, new String(baos.toByteArray()));
+ }
+
+ private LuaScript(JedisPool jedisPool, String script) {
+ this.jedisPool = jedisPool;
+ this.script = script;
+ this.sha = storeScript(jedisPool, script).getBytes();
+ }
+
+ public Object execute(List keys, List args) {
+ try (Jedis jedis = jedisPool.getResource()) {
+ try {
+ return jedis.evalsha(sha, keys, args);
+ } catch (JedisDataException e) {
+ storeScript(jedisPool, script);
+ return jedis.evalsha(sha, keys, args);
+ }
+ }
+ }
+
+ private String storeScript(JedisPool jedisPool, String script) {
+ try (Jedis jedis = jedisPool.getResource()) {
+ return jedis.scriptLoad(script);
+ }
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java
index deb2723c6..c740f44b4 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java
@@ -25,7 +25,7 @@ import java.util.List;
public abstract class Messages {
- public static final int RESULT_SET_CHUNK_SIZE = 100;
+ static final int RESULT_SET_CHUNK_SIZE = 100;
private static final String ID = "id";
private static final String TYPE = "type";
@@ -38,10 +38,9 @@ public abstract class Messages {
private static final String MESSAGE = "message";
private static final String CONTENT = "content";
- @SqlQuery("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
- "VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content) " +
- "RETURNING (SELECT COUNT(id) FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + TYPE + " != " + Envelope.Type.RECEIPT_VALUE + ")")
- abstract int store(@MessageBinder Envelope message,
+ @SqlUpdate("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
+ "VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content)")
+ abstract void store(@MessageBinder Envelope message,
@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
@@ -100,6 +99,7 @@ public abstract class Messages {
}
return new OutgoingMessageEntity(resultSet.getLong(ID),
+ false,
type,
resultSet.getString(RELAY),
resultSet.getLong(TIMESTAMP),
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java
new file mode 100644
index 000000000..b99636c49
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java
@@ -0,0 +1,500 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.codahale.metrics.Histogram;
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.SharedMetricRegistries;
+import com.codahale.metrics.Timer;
+import com.google.common.base.Optional;
+import com.google.protobuf.InvalidProtocolBufferException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
+import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
+import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
+import org.whispersystems.textsecuregcm.push.PushSender;
+import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
+import org.whispersystems.textsecuregcm.redis.LuaScript;
+import org.whispersystems.textsecuregcm.util.Constants;
+import org.whispersystems.textsecuregcm.util.Pair;
+import org.whispersystems.textsecuregcm.util.Util;
+import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static com.codahale.metrics.MetricRegistry.name;
+import io.dropwizard.lifecycle.Managed;
+import redis.clients.jedis.Jedis;
+import redis.clients.jedis.JedisPool;
+import redis.clients.jedis.Tuple;
+import redis.clients.util.SafeEncoder;
+
+public class MessagesCache implements Managed {
+
+ private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
+
+ private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" ));
+ private static final Timer removeByIdTimer = metricRegistry.timer(name(MessagesCache.class, "removeById" ));
+ private static final Timer removeByNameTimer = metricRegistry.timer(name(MessagesCache.class, "removeByName"));
+ private static final Timer getTimer = metricRegistry.timer(name(MessagesCache.class, "get" ));
+ private static final Timer clearAccountTimer = metricRegistry.timer(name(MessagesCache.class, "clearAccount"));
+ private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" ));
+
+ private final JedisPool jedisPool;
+ private final Messages database;
+ private final AccountsManager accountsManager;
+ private final int delayMinutes;
+
+ private InsertOperation insertOperation;
+ private RemoveOperation removeOperation;
+ private GetOperation getOperation;
+
+ private PubSubManager pubSubManager;
+ private PushSender pushSender;
+ private MessagePersister messagePersister;
+
+ public MessagesCache(JedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) {
+ this.jedisPool = jedisPool;
+ this.database = database;
+ this.accountsManager = accountsManager;
+ this.delayMinutes = delayMinutes;
+ }
+
+ public void insert(String destination, long destinationDevice, Envelope message) {
+ Timer.Context timer = insertTimer.time();
+
+ try {
+ insertOperation.insert(destination, destinationDevice, System.currentTimeMillis(), message);
+ } finally {
+ timer.stop();
+ }
+ }
+
+ public void remove(String destination, long destinationDevice, long id) {
+ Timer.Context timer = removeByIdTimer.time();
+
+ try {
+ removeOperation.remove(destination, destinationDevice, id);
+ } finally {
+ timer.stop();
+ }
+ }
+
+ public Optional remove(String destination, long destinationDevice, String sender, long timestamp) {
+ Timer.Context timer = removeByNameTimer.time();
+
+ try {
+ byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
+
+ if (serialized != null) {
+ Envelope envelope = Envelope.parseFrom(serialized);
+ return Optional.of(constructEntityFromEnvelope(0, envelope));
+ }
+ } catch (InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ } finally {
+ timer.stop();
+ }
+
+ return Optional.absent();
+ }
+
+ public List get(String destination, long destinationDevice, int limit) {
+ Timer.Context timer = getTimer.time();
+
+ try {
+ List results = new LinkedList<>();
+ Key key = new Key(destination, destinationDevice);
+ List> items = getOperation.getItems(key.getUserMessageQueue(), key.getUserMessageQueuePersistInProgress(), limit);
+
+ for (Pair item : items) {
+ try {
+ long id = item.second().longValue();
+ Envelope message = Envelope.parseFrom(item.first());
+ results.add(constructEntityFromEnvelope(id, message));
+ } catch (InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ }
+ }
+
+ return results;
+ } finally {
+ timer.stop();
+ }
+ }
+
+ public void clear(String destination) {
+ Timer.Context timer = clearAccountTimer.time();
+
+ try {
+ for (int i = 1; i < 255; i++) {
+ clear(destination, i);
+ }
+ } finally {
+ timer.stop();
+ }
+ }
+
+ public void clear(String destination, long deviceId) {
+ Timer.Context timer = clearDeviceTimer.time();
+
+ try {
+ removeOperation.clear(destination, deviceId);
+ } finally {
+ timer.stop();
+ }
+ }
+
+ public void setPubSubManager(PubSubManager pubSubManager, PushSender pushSender) {
+ this.pubSubManager = pubSubManager;
+ this.pushSender = pushSender;
+ }
+
+ @Override
+ public void start() throws Exception {
+ this.insertOperation = new InsertOperation(jedisPool);
+ this.removeOperation = new RemoveOperation(jedisPool);
+ this.getOperation = new GetOperation(jedisPool);
+ this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
+
+ this.messagePersister.start();
+ }
+
+ @Override
+ public void stop() throws Exception {
+ messagePersister.shutdown();
+ logger.info("Message persister shut down...");
+ }
+
+ private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) {
+ return new OutgoingMessageEntity(id, true,
+ envelope.getType().getNumber(),
+ envelope.getRelay(),
+ envelope.getTimestamp(),
+ envelope.getSource(),
+ envelope.getSourceDevice(),
+ envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
+ envelope.hasContent() ? envelope.getContent().toByteArray() : null);
+ }
+
+ private static class Key {
+
+ private final byte[] userMessageQueue;
+ private final byte[] userMessageQueueMetadata;
+ private final byte[] userMessageQueuePersistInProgress;
+
+ private final String address;
+ private final long deviceId;
+
+ Key(String address, long deviceId) {
+ this.address = address;
+ this.deviceId = deviceId;
+ this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes();
+ this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes();
+ this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes();
+ }
+
+ String getAddress() {
+ return address;
+ }
+
+ long getDeviceId() {
+ return deviceId;
+ }
+
+ byte[] getUserMessageQueue() {
+ return userMessageQueue;
+ }
+
+ byte[] getUserMessageQueueMetadata() {
+ return userMessageQueueMetadata;
+ }
+
+ byte[] getUserMessageQueuePersistInProgress() {
+ return userMessageQueuePersistInProgress;
+ }
+
+ static byte[] getUserMessageQueueIndex() {
+ return "user_queue_index".getBytes();
+ }
+
+ static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException {
+ try {
+ String[] parts = new String(userMessageQueue).split("::");
+
+ if (parts.length != 3) {
+ throw new IOException("Malformed key: " + new String(userMessageQueue));
+ }
+
+ return new Key(parts[1], Long.parseLong(parts[2]));
+ } catch (NumberFormatException e) {
+ throw new IOException(e);
+ }
+ }
+ }
+
+ private static class InsertOperation {
+ private final LuaScript insert;
+
+ InsertOperation(JedisPool jedisPool) throws IOException {
+ this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
+ }
+
+ public void insert(String destination, long destinationDevice, long timestamp, Envelope message) {
+ Key key = new Key(destination, destinationDevice);
+ String sender = message.getSource() + "::" + message.getTimestamp();
+
+ List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
+ List args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes());
+
+ insert.execute(keys, args);
+ }
+ }
+
+ private static class RemoveOperation {
+
+ private final LuaScript removeById;
+ private final LuaScript removeBySender;
+ private final LuaScript removeQueue;
+
+ RemoveOperation(JedisPool jedisPool) throws IOException {
+ this.removeById = LuaScript.fromResource(jedisPool, "lua/remove_item_by_id.lua" );
+ this.removeBySender = LuaScript.fromResource(jedisPool, "lua/remove_item_by_sender.lua");
+ this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
+ }
+
+ public void remove(String destination, long destinationDevice, long id) {
+ Key key = new Key(destination, destinationDevice);
+
+ List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
+ List args = Collections.singletonList(String.valueOf(id).getBytes());
+
+ this.removeById.execute(keys, args);
+ }
+
+ public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {
+ Key key = new Key(destination, destinationDevice);
+ String senderKey = sender + "::" + timestamp;
+
+ List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
+ List args = Collections.singletonList(senderKey.getBytes());
+
+ return (byte[])this.removeBySender.execute(keys, args);
+ }
+
+ public void clear(String destination, long deviceId) {
+ Key key = new Key(destination, deviceId);
+
+ List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
+ List args = new LinkedList<>();
+
+ this.removeQueue.execute(keys, args);
+ }
+ }
+
+ private static class GetOperation {
+
+ private final LuaScript getQueues;
+ private final LuaScript getItems;
+
+ GetOperation(JedisPool jedisPool) throws IOException {
+ this.getQueues = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua");
+ this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua");
+ }
+
+ List getQueues(byte[] queue, long maxTimeMillis, int limit) {
+ List keys = Collections.singletonList(queue);
+ List args = Arrays.asList(String.valueOf(maxTimeMillis).getBytes(), String.valueOf(limit).getBytes());
+
+ return (List)getQueues.execute(keys, args);
+ }
+
+ List> getItems(byte[] queue, byte[] lock, int limit) {
+ List keys = Arrays.asList(queue, lock);
+ List args = Collections.singletonList(String.valueOf(limit).getBytes());
+
+ Iterator results = ((List) getItems.execute(keys, args)).iterator();
+ List> items = new LinkedList<>();
+
+ while (results.hasNext()) {
+ items.add(new Pair<>(results.next(), Double.valueOf(SafeEncoder.encode(results.next()))));
+ }
+
+ return items;
+ }
+ }
+
+ private static class MessagePersister extends Thread {
+
+ private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
+ private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ private static final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" ));
+ private static final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue"));
+ private static final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" ));
+ private static final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" ));
+ private static final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount"));
+
+ private static final int CHUNK_SIZE = 100;
+
+ private final AtomicBoolean running = new AtomicBoolean(true);
+
+ private final JedisPool jedisPool;
+ private final Messages database;
+ private final long delayTime;
+ private final TimeUnit delayTimeUnit;
+
+ private final PubSubManager pubSubManager;
+ private final PushSender pushSender;
+ private final AccountsManager accountsManager;
+
+ private final GetOperation getOperation;
+ private final RemoveOperation removeOperation;
+
+ private boolean finished = false;
+
+ MessagePersister(JedisPool jedisPool,
+ Messages database,
+ PubSubManager pubSubManager,
+ PushSender pushSender,
+ AccountsManager accountsManager,
+ long delayTime,
+ TimeUnit delayTimeUnit)
+ throws IOException
+ {
+ super(MessagePersister.class.getSimpleName());
+ this.jedisPool = jedisPool;
+ this.database = database;
+
+ this.pubSubManager = pubSubManager;
+ this.pushSender = pushSender;
+ this.accountsManager = accountsManager;
+
+ this.delayTime = delayTime;
+ this.delayTimeUnit = delayTimeUnit;
+ this.getOperation = new GetOperation(jedisPool);
+ this.removeOperation = new RemoveOperation(jedisPool);
+ }
+
+ @Override
+ public void run() {
+ while (running.get()) {
+ try {
+ List queuesToPersist = getQueuesToPersist(getOperation);
+ queueCountHistogram.update(queuesToPersist.size());
+
+ for (byte[] queue : queuesToPersist) {
+ Key key = Key.fromUserMessageQueue(queue);
+
+ persistQueue(jedisPool, key);
+ notifyClients(accountsManager, pubSubManager, pushSender, key);
+ }
+
+ if (queuesToPersist.isEmpty()) {
+ Thread.sleep(10000);
+ }
+ } catch (Throwable t) {
+ logger.error("Exception while persisting: ", t);
+ }
+ }
+
+ synchronized (this) {
+ finished = true;
+ notifyAll();
+ }
+ }
+
+ synchronized void shutdown() {
+ running.set(false);
+ while (!finished) Util.wait(this);
+ }
+
+ private void persistQueue(JedisPool jedisPool, Key key) throws IOException {
+ Timer.Context timer = persistQueueTimer.time();
+
+ int messagesPersistedCount = 0;
+
+ try (Jedis jedis = jedisPool.getResource()) {
+ while (true) {
+ jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes());
+
+ Set messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE);
+
+ for (Tuple message : messages) {
+ persistMessage(key, (long)message.getScore(), message.getBinaryElement());
+ messagesPersistedCount++;
+ }
+
+ if (messages.size() < CHUNK_SIZE) {
+ jedis.del(key.getUserMessageQueuePersistInProgress());
+ return;
+ }
+ }
+ } finally {
+ timer.stop();
+ queueSizeHistogram.update(messagesPersistedCount);
+ }
+ }
+
+ private void persistMessage(Key key, long score, byte[] message) {
+ try {
+ Envelope envelope = Envelope.parseFrom(message);
+ database.store(envelope, key.getAddress(), key.getDeviceId());
+ } catch (InvalidProtocolBufferException e) {
+ logger.error("Error parsing envelope", e);
+ }
+
+ removeOperation.remove(key.getAddress(), key.getDeviceId(), score);
+ }
+
+ private List getQueuesToPersist(GetOperation getOperation) {
+ Timer.Context timer = getQueuesTimer.time();
+ try {
+ long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime);
+ return getOperation.getQueues(Key.getUserMessageQueueIndex(), maxTime, 100);
+ } finally {
+ timer.stop();
+ }
+ }
+
+ private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key)
+ throws IOException
+ {
+ Timer.Context timer = notifyTimer.time();
+
+ try {
+ boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()),
+ PubSubProtos.PubSubMessage.newBuilder()
+ .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
+ .build());
+
+ if (!notified) {
+ Optional account = accountsManager.get(key.getAddress());
+
+ if (account.isPresent()) {
+ Optional device = account.get().getDevice(key.getDeviceId());
+
+ if (device.isPresent()) {
+ try {
+ pushSender.sendQueuedNotification(account.get(), device.get(), false);
+ } catch (NotPushRegisteredException e) {
+ logger.warn("After message persistence, no longer push registered!");
+ } catch (TransientPushFailureException e) {
+ logger.warn("Transient push failure!", e);
+ }
+ }
+ }
+ }
+ } finally {
+ timer.stop();
+ }
+ }
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java
index fff0213bf..fbafcd232 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java
@@ -1,44 +1,118 @@
package org.whispersystems.textsecuregcm.storage;
+import com.codahale.metrics.Meter;
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
+import org.whispersystems.textsecuregcm.util.Constants;
+import org.whispersystems.textsecuregcm.util.Conversions;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
import java.util.List;
+import static com.codahale.metrics.MetricRegistry.name;
+
public class MessagesManager {
- private final Messages messages;
+ private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ private static final Meter cacheHitByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitById" ));
+ private static final Meter cacheMissByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissById" ));
+ private static final Meter cacheHitByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByName" ));
+ private static final Meter cacheMissByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissByName"));
- public MessagesManager(Messages messages) {
- this.messages = messages;
+ private final Messages messages;
+ private final MessagesCache messagesCache;
+ private final Distribution distribution;
+
+ public MessagesManager(Messages messages, MessagesCache messagesCache, float cacheRate) {
+ this.messages = messages;
+ this.messagesCache = messagesCache;
+ this.distribution = new Distribution(cacheRate);
}
- public int insert(String destination, long destinationDevice, Envelope message) {
- return this.messages.store(message, destination, destinationDevice) + 1;
+ public void insert(String destination, long destinationDevice, Envelope message) {
+ if (distribution.isQualified(destination, destinationDevice)) {
+ messagesCache.insert(destination, destinationDevice, message);
+ } else {
+ messages.store(message, destination, destinationDevice);
+ }
}
public OutgoingMessageEntityList getMessagesForDevice(String destination, long destinationDevice) {
List messages = this.messages.load(destination, destinationDevice);
+
+ if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) {
+ messages.addAll(this.messagesCache.get(destination, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size()));
+ }
+
return new OutgoingMessageEntityList(messages, messages.size() >= Messages.RESULT_SET_CHUNK_SIZE);
}
public void clear(String destination) {
+ this.messagesCache.clear(destination);
this.messages.clear(destination);
}
public void clear(String destination, long deviceId) {
+ this.messagesCache.clear(destination, deviceId);
this.messages.clear(destination, deviceId);
}
public Optional delete(String destination, long destinationDevice, String source, long timestamp)
{
- return Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp));
+ Optional removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp);
+
+ if (!removed.isPresent()) {
+ removed = Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp));
+ cacheMissByNameMeter.mark();
+ } else {
+ cacheHitByNameMeter.mark();
+ }
+
+ return removed;
}
- public void delete(String destination, long id) {
- this.messages.remove(destination, id);
+ public void delete(String destination, long deviceId, long id, boolean cached) {
+ if (cached) {
+ this.messagesCache.remove(destination, deviceId, id);
+ cacheHitByIdMeter.mark();
+ } else {
+ this.messages.remove(destination, id);
+ cacheMissByIdMeter.mark();
+ }
}
+
+ public static class Distribution {
+
+ private final float percentage;
+
+ public Distribution(float percentage) {
+ this.percentage = percentage;
+ }
+
+ public boolean isQualified(String address, long device) {
+ if (percentage <= 0) return false;
+ if (percentage >= 100) return true;
+
+ try {
+ MessageDigest digest = MessageDigest.getInstance("SHA1");
+ digest.update(address.getBytes());
+ digest.update(Conversions.longToByteArray(device));
+
+ byte[] result = digest.digest();
+ int hashCode = Conversions.byteArrayToShort(result);
+
+ return hashCode <= 65535 * percentage;
+ } catch (NoSuchAlgorithmException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ }
+
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
index 7fb703c2b..723fd11b3 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -82,7 +82,7 @@ public class WebSocketConnection implements DispatchChannel {
processStoredMessages();
break;
case PubSubMessage.Type.DELIVER_VALUE:
- sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false);
+ sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false);
break;
case PubSubMessage.Type.CONNECTED_VALUE:
if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) {
@@ -106,9 +106,9 @@ public class WebSocketConnection implements DispatchChannel {
processStoredMessages();
}
- private void sendMessage(final Envelope message,
- final Optional storedMessageId,
- final boolean requery)
+ private void sendMessage(final Envelope message,
+ final Optional storedMessageInfo,
+ final boolean requery)
{
try {
EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, device.getSignalingKey());
@@ -125,17 +125,17 @@ public class WebSocketConnection implements DispatchChannel {
}
if (isSuccessResponse(response)) {
- if (storedMessageId.isPresent()) messagesManager.delete(account.getNumber(), storedMessageId.get());
- if (!isReceipt) sendDeliveryReceiptFor(message);
- if (requery) processStoredMessages();
- } else if (!isSuccessResponse(response) && !storedMessageId.isPresent()) {
+ if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached);
+ if (!isReceipt) sendDeliveryReceiptFor(message);
+ if (requery) processStoredMessages();
+ } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) {
requeueMessage(message);
}
}
@Override
public void onFailure(@Nonnull Throwable throwable) {
- if (!storedMessageId.isPresent()) requeueMessage(message);
+ if (!storedMessageInfo.isPresent()) requeueMessage(message);
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
@@ -148,11 +148,12 @@ public class WebSocketConnection implements DispatchChannel {
}
private void requeueMessage(Envelope message) {
- int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message);
- boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT;
+ pushSender.getWebSocketSender().queueMessage(account, device, message);
+
+ boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT;
try {
- pushSender.sendQueuedNotification(account, device, queueDepth, fallback);
+ pushSender.sendQueuedNotification(account, device, fallback);
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("requeueMessage", e);
}
@@ -162,7 +163,7 @@ public class WebSocketConnection implements DispatchChannel {
try {
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp(),
message.hasRelay() ? Optional.of(message.getRelay()) :
- Optional.absent());
+ Optional.absent());
} catch (NoSuchUserException | NotPushRegisteredException e) {
logger.info("No longer registered " + e.getMessage());
} catch(IOException | TransientPushFailureException e) {
@@ -196,11 +197,21 @@ public class WebSocketConnection implements DispatchChannel {
builder.setRelay(message.getRelay());
}
- sendMessage(builder.build(), Optional.of(message.getId()), !iterator.hasNext() && messages.hasMore());
+ sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore());
}
if (!messages.hasMore()) {
client.sendRequest("PUT", "/api/v1/queue/empty", null, Optional.absent());
}
}
+
+ private static class StoredMessageInfo {
+ private final long id;
+ private final boolean cached;
+
+ private StoredMessageInfo(long id, boolean cached) {
+ this.id = id;
+ this.cached = cached;
+ }
+ }
}
diff --git a/src/main/resources/lua/get_items.lua b/src/main/resources/lua/get_items.lua
new file mode 100644
index 000000000..7a966d039
--- /dev/null
+++ b/src/main/resources/lua/get_items.lua
@@ -0,0 +1,10 @@
+-- keys: queue_key, queue_locked_key
+-- argv: limit
+
+local locked = redis.call("GET", KEYS[2])
+
+if locked then
+ return {}
+end
+
+return redis.call("ZRANGE", KEYS[1], 0, ARGV[1], "WITHSCORES")
\ No newline at end of file
diff --git a/src/main/resources/lua/get_queues_to_persist.lua b/src/main/resources/lua/get_queues_to_persist.lua
new file mode 100644
index 000000000..b4afdaf43
--- /dev/null
+++ b/src/main/resources/lua/get_queues_to_persist.lua
@@ -0,0 +1,10 @@
+-- keys: queue_total_index
+-- argv: max_time, limit
+
+local results = redis.call("ZRANGEBYSCORE", KEYS[1], 0, ARGV[1], "LIMIT", 0, ARGV[2])
+
+if results and next(results) then
+ redis.call("ZREM", KEYS[1], unpack(results))
+end
+
+return results
diff --git a/src/main/resources/lua/insert_item.lua b/src/main/resources/lua/insert_item.lua
new file mode 100644
index 000000000..4f36a69df
--- /dev/null
+++ b/src/main/resources/lua/insert_item.lua
@@ -0,0 +1,13 @@
+-- keys: queue_key, queue_metadata_key, queue_total_index
+-- argv: message, current_time, sender_key
+
+local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
+redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1])
+redis.call("HSET", KEYS[2], ARGV[3], messageId)
+redis.call("HSET", KEYS[2], messageId, ARGV[3])
+
+redis.call("EXPIRE", KEYS[1], 7776000)
+redis.call("EXPIRE", KEYS[2], 7776000)
+
+redis.call("ZADD", KEYS[3], "NX", ARGV[2], KEYS[1])
+return messageId
\ No newline at end of file
diff --git a/src/main/resources/lua/remove_item_by_id.lua b/src/main/resources/lua/remove_item_by_id.lua
new file mode 100644
index 000000000..b5e8496ba
--- /dev/null
+++ b/src/main/resources/lua/remove_item_by_id.lua
@@ -0,0 +1,16 @@
+-- keys: queue_key, queue_metadata_key, queue_index
+-- argv: index_to_remove
+
+local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1])
+local senderIndex = redis.call("HGET", KEYS[2], ARGV[1])
+
+if senderIndex then
+ redis.call("HDEL", KEYS[2], senderIndex)
+ redis.call("HDEL", KEYS[2], ARGV[1])
+end
+
+if (redis.call("ZCARD", KEYS[1]) == 0) then
+ redis.call("ZREM", KEYS[3], KEYS[1])
+end
+
+return removedCount > 0
diff --git a/src/main/resources/lua/remove_item_by_sender.lua b/src/main/resources/lua/remove_item_by_sender.lua
new file mode 100644
index 000000000..9aaf7639f
--- /dev/null
+++ b/src/main/resources/lua/remove_item_by_sender.lua
@@ -0,0 +1,22 @@
+-- keys: queue_key, queue_metadata_key, queue_index
+-- argv: sender_to_remove
+
+local messageId = redis.call("HGET", KEYS[2], ARGV[1])
+
+if messageId then
+ local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], messageId, messageId, "LIMIT", 0, 1)
+
+ redis.call("ZREMRANGEBYSCORE", KEYS[1], messageId, messageId)
+ redis.call("HDEL", KEYS[2], ARGV[1])
+ redis.call("HDEL", KEYS[2], messageId)
+
+ if (redis.call("ZCARD", KEYS[1]) == 0) then
+ redis.call("ZREM", KEYS[3], KEYS[1])
+ end
+
+ if envelope and next(envelope) then
+ return envelope[1]
+ end
+end
+
+return nil
diff --git a/src/main/resources/lua/remove_queue.lua b/src/main/resources/lua/remove_queue.lua
new file mode 100644
index 000000000..315a0fa7e
--- /dev/null
+++ b/src/main/resources/lua/remove_queue.lua
@@ -0,0 +1,5 @@
+-- keys: queue_key, queue_metadata_key, queue_index
+
+redis.call("DEL", KEYS[1])
+redis.call("DEL", KEYS[2])
+redis.call("ZREM", KEYS[3], KEYS[1])
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java
index 809d47272..efa222821 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java
@@ -186,8 +186,8 @@ public class MessageControllerTest {
final long timestampTwo = 313388;
List messages = new LinkedList() {{
- add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
- add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
+ add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
+ add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@@ -217,8 +217,8 @@ public class MessageControllerTest {
final long timestampTwo = 313388;
List messages = new LinkedList() {{
- add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
- add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
+ add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
+ add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@@ -239,13 +239,13 @@ public class MessageControllerTest {
public synchronized void testDeleteMessages() throws Exception {
long timestamp = System.currentTimeMillis();
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337))
- .thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
+ .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true,
Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp,
"+14152222222", 1, "hi".getBytes(), null)));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338))
- .thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
+ .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true,
Envelope.Type.RECEIPT_VALUE,
null, System.currentTimeMillis(),
"+14152222222", 1, null, null)));
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
index 949e5b1e5..2e10578c3 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
@@ -100,9 +100,9 @@ public class WebSocketConnectionTest {
MessagesManager storedMessages = mock(MessagesManager.class);
List outgoingMessages = new LinkedList () {{
- add(createMessage(1L, "sender1", 1111, false, "first"));
- add(createMessage(2L, "sender1", 2222, false, "second"));
- add(createMessage(3L, "sender2", 3333, false, "third"));
+ add(createMessage(1L, false, "sender1", 1111, false, "first"));
+ add(createMessage(2L, false, "sender1", 2222, false, "second"));
+ add(createMessage(3L, false, "sender2", 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
@@ -157,7 +157,7 @@ public class WebSocketConnectionTest {
futures.get(0).setException(new IOException());
futures.get(2).setException(new IOException());
- verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L));
+ verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L), eq(2L), eq(false));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L), eq(Optional.absent()));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
@@ -170,7 +170,6 @@ public class WebSocketConnectionTest {
WebsocketSender websocketSender = mock(WebsocketSender.class);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
- when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10);
Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
@@ -251,7 +250,7 @@ public class WebSocketConnectionTest {
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.absent()));
verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class));
- verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(10), eq(true));
+ verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(true));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
@@ -266,7 +265,6 @@ public class WebSocketConnectionTest {
reset(pushSender);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
- when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10);
final Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
@@ -285,11 +283,11 @@ public class WebSocketConnectionTest {
.build();
List pendingMessages = new LinkedList() {{
- add(new OutgoingMessageEntity(1, firstMessage.getType().getNumber(), firstMessage.getRelay(),
+ add(new OutgoingMessageEntity(1, true, firstMessage.getType().getNumber(), firstMessage.getRelay(),
firstMessage.getTimestamp(), firstMessage.getSource(),
firstMessage.getSourceDevice(), firstMessage.getLegacyMessage().toByteArray(),
firstMessage.getContent().toByteArray()));
- add(new OutgoingMessageEntity(2, secondMessage.getType().getNumber(), secondMessage.getRelay(),
+ add(new OutgoingMessageEntity(2, false, secondMessage.getType().getNumber(), secondMessage.getRelay(),
secondMessage.getTimestamp(), secondMessage.getSource(),
secondMessage.getSourceDevice(), secondMessage.getLegacyMessage().toByteArray(),
secondMessage.getContent().toByteArray()));
@@ -355,8 +353,8 @@ public class WebSocketConnectionTest {
}
- private OutgoingMessageEntity createMessage(long id, String sender, long timestamp, boolean receipt, String content) {
- return new OutgoingMessageEntity(id, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
+ private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, long timestamp, boolean receipt, String content) {
+ return new OutgoingMessageEntity(id, cached, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, 1, content.getBytes(), null);
}