Support for messagedb caching

This commit is contained in:
Moxie Marlinspike 2018-04-10 11:35:55 -07:00
parent 35d6bfb3a8
commit 9923a07c25
20 changed files with 857 additions and 89 deletions

View File

@ -78,7 +78,7 @@
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.7.3</version>
<version>2.9.0</version>
<type>jar</type>
<scope>compile</scope>
</dependency>

View File

@ -21,6 +21,7 @@ import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.FederationConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.ProfilesConfiguration;
import org.whispersystems.textsecuregcm.configuration.PushConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
@ -75,6 +76,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private RedisConfiguration directory;
@NotNull
@Valid
@JsonProperty
private MessageCacheConfiguration messageCache;
@Valid
@NotNull
@JsonProperty
@ -160,6 +166,10 @@ public class WhisperServerConfiguration extends Configuration {
return directory;
}
public MessageCacheConfiguration getMessageCacheConfiguration() {
return messageCache;
}
public DataSourceFactory getMessageStoreConfiguration() {
return messageStore;
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -24,7 +24,6 @@ import com.google.common.base.Optional;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.dropwizard.simpleauth.AuthDynamicFeature;
import org.whispersystems.dropwizard.simpleauth.AuthValueFactoryProvider;
@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.Messages;
import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
@ -161,15 +161,17 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl());
JedisPool cacheClient = cacheClientFactory.getRedisClientPool();
JedisPool directoryClient = new RedisClientFactory(config.getDirectoryConfiguration().getUrl()).getRedisClientPool();
JedisPool messagesClient = new RedisClientFactory(config.getMessageCacheConfiguration().getRedisConfiguration().getUrl()).getRedisClientPool();
DirectoryManager directory = new DirectoryManager(directoryClient);
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheClient);
PendingDevicesManager pendingDevicesManager = new PendingDevicesManager (pendingDevices, cacheClient );
AccountsManager accountsManager = new AccountsManager(accounts, directory, cacheClient);
FederatedClientManager federatedClientManager = new FederatedClientManager(environment, config.getJerseyClientConfiguration(), config.getFederationConfiguration());
MessagesManager messagesManager = new MessagesManager(messages);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes());
MessagesManager messagesManager = new MessagesManager(messages, messagesCache, config.getMessageCacheConfiguration().getCacheRate());
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);
DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.<DispatchChannel>of(deadLetterHandler));
DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.of(deadLetterHandler));
PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
APNSender apnSender = new APNSender(accountsManager, config.getApnConfiguration());
GCMSender gcmSender = new GCMSender(accountsManager, config.getGcmConfiguration().getApiKey());
@ -186,10 +188,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ReceiptSender receiptSender = new ReceiptSender(accountsManager, pushSender, federatedClientManager);
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(config.getTurnConfiguration());
messagesCache.setPubSubManager(pubSubManager, pushSender);
apnSender.setApnFallbackManager(apnFallbackManager);
environment.lifecycle().manage(apnFallbackManager);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(pushSender);
environment.lifecycle().manage(messagesCache);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager);

View File

@ -0,0 +1,38 @@
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.glassfish.jersey.server.JSONP;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
public class MessageCacheConfiguration {
@JsonProperty
@NotNull
@Valid
private RedisConfiguration redis;
@JsonProperty
private int persistDelayMinutes = 10;
@JsonProperty
@Min(0)
@Max(1)
private float cacheRate = 1;
public RedisConfiguration getRedisConfiguration() {
return redis;
}
public int getPersistDelayMinutes() {
return persistDelayMinutes;
}
public float getCacheRate() {
return cacheRate;
}
}

View File

@ -8,6 +8,9 @@ public class OutgoingMessageEntity {
@JsonIgnore
private long id;
@JsonIgnore
private boolean cached;
@JsonProperty
private int type;
@ -31,11 +34,12 @@ public class OutgoingMessageEntity {
public OutgoingMessageEntity() {}
public OutgoingMessageEntity(long id, int type, String relay, long timestamp,
public OutgoingMessageEntity(long id, boolean cached, int type, String relay, long timestamp,
String source, int sourceDevice, byte[] message,
byte[] content)
{
this.id = id;
this.cached = cached;
this.type = type;
this.relay = relay;
this.timestamp = timestamp;
@ -73,8 +77,14 @@ public class OutgoingMessageEntity {
return content;
}
@JsonIgnore
public long getId() {
return id;
}
@JsonIgnore
public boolean isCached() {
return cached;
}
}

View File

@ -39,7 +39,7 @@ public class PushSender implements Managed {
private final Logger logger = LoggerFactory.getLogger(PushSender.class);
public static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"badge\":%d,\"alert\":{\"loc-key\":\"APN_Message\"}}}";
private static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}";
private final ApnFallbackManager apnFallbackManager;
private final GCMSender gcmSender;
@ -61,12 +61,7 @@ public class PushSender implements Managed {
SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME)
.register(name(PushSender.class, "send_queue_depth"),
new Gauge<Integer>() {
@Override
public Integer getValue() {
return executor.getSize();
}
});
(Gauge<Integer>) executor::getSize);
}
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean silent)
@ -77,22 +72,17 @@ public class PushSender implements Managed {
}
if (queueSize > 0) {
executor.execute(new Runnable() {
@Override
public void run() {
sendSynchronousMessage(account, device, message, silent);
}
});
executor.execute(() -> sendSynchronousMessage(account, device, message, silent));
} else {
sendSynchronousMessage(account, device, message, silent);
}
}
public void sendQueuedNotification(Account account, Device device, int messageQueueDepth, boolean fallback)
public void sendQueuedNotification(Account account, Device device, boolean fallback)
throws NotPushRegisteredException, TransientPushFailureException
{
if (device.getGcmId() != null) sendGcmNotification(account, device);
else if (device.getApnId() != null) sendApnNotification(account, device, messageQueueDepth, fallback);
else if (device.getApnId() != null) sendApnNotification(account, device, fallback);
else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!");
}
@ -127,16 +117,15 @@ public class PushSender implements Managed {
if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) {
boolean fallback = !silent && !outgoingMessage.getSource().equals(account.getNumber());
sendApnNotification(account, device, deliveryStatus.getMessageQueueDepth(), fallback);
sendApnNotification(account, device, fallback);
}
}
private void sendApnNotification(Account account, Device device, int messageQueueDepth, boolean fallback) {
private void sendApnNotification(Account account, Device device, boolean fallback) {
ApnMessage apnMessage;
if (!Util.isEmpty(device.getVoipApnId())) {
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(),
String.format(APN_PAYLOAD, messageQueueDepth), true,
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, true,
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION));
if (fallback) {
@ -144,8 +133,7 @@ public class PushSender implements Managed {
new ApnFallbackTask(device.getApnId(), device.getVoipApnId(), apnMessage));
}
} else {
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(),
String.format(APN_PAYLOAD, messageQueueDepth),
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD,
false, ApnMessage.MAX_EXPIRATION);
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -36,12 +36,13 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag
public class WebsocketSender {
public static enum Type {
public enum Type {
APN,
GCM,
WEB
}
@SuppressWarnings("unused")
private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
@ -79,28 +80,26 @@ public class WebsocketSender {
else if (channel == Type.GCM) gcmOnlineMeter.mark();
else websocketOnlineMeter.mark();
return new DeliveryStatus(true, 0);
return new DeliveryStatus(true);
} else {
if (channel == Type.APN) apnOfflineMeter.mark();
else if (channel == Type.GCM) gcmOfflineMeter.mark();
else websocketOfflineMeter.mark();
int queueDepth = queueMessage(account, device, message);
return new DeliveryStatus(false, queueDepth);
queueMessage(account, device, message);
return new DeliveryStatus(false);
}
}
public int queueMessage(Account account, Device device, Envelope message) {
public void queueMessage(Account account, Device device, Envelope message) {
websocketRequeueMeter.mark();
WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
int queueDepth = messagesManager.insert(account.getNumber(), device.getId(), message);
WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
messagesManager.insert(account.getNumber(), device.getId(), message);
pubSubManager.publish(address, PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.QUERY_DB)
.build());
return queueDepth;
}
public boolean sendProvisioningMessage(ProvisioningAddress address, byte[] body) {
@ -118,22 +117,17 @@ public class WebsocketSender {
}
}
public static class DeliveryStatus {
static class DeliveryStatus {
private final boolean delivered;
private final int messageQueueDepth;
public DeliveryStatus(boolean delivered, int messageQueueDepth) {
DeliveryStatus(boolean delivered) {
this.delivered = delivered;
this.messageQueueDepth = messageQueueDepth;
}
public boolean isDelivered() {
boolean isDelivered() {
return delivered;
}
public int getMessageQueueDepth() {
return messageQueueDepth;
}
}
}

View File

@ -0,0 +1,64 @@
package org.whispersystems.textsecuregcm.redis;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.List;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.exceptions.JedisDataException;
public class LuaScript {
private final JedisPool jedisPool;
private final String script;
private final byte[] sha;
public static LuaScript fromResource(JedisPool jedisPool, String resource) throws IOException {
InputStream inputStream = LuaScript.class.getClassLoader().getResourceAsStream(resource);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int read;
while ((read = inputStream.read(buffer)) != -1) {
baos.write(buffer, 0, read);
}
inputStream.close();
baos.close();
return new LuaScript(jedisPool, new String(baos.toByteArray()));
}
private LuaScript(JedisPool jedisPool, String script) {
this.jedisPool = jedisPool;
this.script = script;
this.sha = storeScript(jedisPool, script).getBytes();
}
public Object execute(List<byte[]> keys, List<byte[]> args) {
try (Jedis jedis = jedisPool.getResource()) {
try {
return jedis.evalsha(sha, keys, args);
} catch (JedisDataException e) {
storeScript(jedisPool, script);
return jedis.evalsha(sha, keys, args);
}
}
}
private String storeScript(JedisPool jedisPool, String script) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.scriptLoad(script);
}
}
}

View File

@ -25,7 +25,7 @@ import java.util.List;
public abstract class Messages {
public static final int RESULT_SET_CHUNK_SIZE = 100;
static final int RESULT_SET_CHUNK_SIZE = 100;
private static final String ID = "id";
private static final String TYPE = "type";
@ -38,10 +38,9 @@ public abstract class Messages {
private static final String MESSAGE = "message";
private static final String CONTENT = "content";
@SqlQuery("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
"VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content) " +
"RETURNING (SELECT COUNT(id) FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + TYPE + " != " + Envelope.Type.RECEIPT_VALUE + ")")
abstract int store(@MessageBinder Envelope message,
@SqlUpdate("INSERT INTO messages (" + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SOURCE + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " +
"VALUES (:type, :relay, :timestamp, :source, :source_device, :destination, :destination_device, :message, :content)")
abstract void store(@MessageBinder Envelope message,
@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
@ -100,6 +99,7 @@ public abstract class Messages {
}
return new OutgoingMessageEntity(resultSet.getLong(ID),
false,
type,
resultSet.getString(RELAY),
resultSet.getLong(TIMESTAMP),

View File

@ -0,0 +1,500 @@
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.common.base.Optional;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
public class MessagesCache implements Managed {
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" ));
private static final Timer removeByIdTimer = metricRegistry.timer(name(MessagesCache.class, "removeById" ));
private static final Timer removeByNameTimer = metricRegistry.timer(name(MessagesCache.class, "removeByName"));
private static final Timer getTimer = metricRegistry.timer(name(MessagesCache.class, "get" ));
private static final Timer clearAccountTimer = metricRegistry.timer(name(MessagesCache.class, "clearAccount"));
private static final Timer clearDeviceTimer = metricRegistry.timer(name(MessagesCache.class, "clearDevice" ));
private final JedisPool jedisPool;
private final Messages database;
private final AccountsManager accountsManager;
private final int delayMinutes;
private InsertOperation insertOperation;
private RemoveOperation removeOperation;
private GetOperation getOperation;
private PubSubManager pubSubManager;
private PushSender pushSender;
private MessagePersister messagePersister;
public MessagesCache(JedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
}
public void insert(String destination, long destinationDevice, Envelope message) {
Timer.Context timer = insertTimer.time();
try {
insertOperation.insert(destination, destinationDevice, System.currentTimeMillis(), message);
} finally {
timer.stop();
}
}
public void remove(String destination, long destinationDevice, long id) {
Timer.Context timer = removeByIdTimer.time();
try {
removeOperation.remove(destination, destinationDevice, id);
} finally {
timer.stop();
}
}
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String sender, long timestamp) {
Timer.Context timer = removeByNameTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
if (serialized != null) {
Envelope envelope = Envelope.parseFrom(serialized);
return Optional.of(constructEntityFromEnvelope(0, envelope));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
} finally {
timer.stop();
}
return Optional.absent();
}
public List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit) {
Timer.Context timer = getTimer.time();
try {
List<OutgoingMessageEntity> results = new LinkedList<>();
Key key = new Key(destination, destinationDevice);
List<Pair<byte[], Double>> items = getOperation.getItems(key.getUserMessageQueue(), key.getUserMessageQueuePersistInProgress(), limit);
for (Pair<byte[], Double> item : items) {
try {
long id = item.second().longValue();
Envelope message = Envelope.parseFrom(item.first());
results.add(constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
return results;
} finally {
timer.stop();
}
}
public void clear(String destination) {
Timer.Context timer = clearAccountTimer.time();
try {
for (int i = 1; i < 255; i++) {
clear(destination, i);
}
} finally {
timer.stop();
}
}
public void clear(String destination, long deviceId) {
Timer.Context timer = clearDeviceTimer.time();
try {
removeOperation.clear(destination, deviceId);
} finally {
timer.stop();
}
}
public void setPubSubManager(PubSubManager pubSubManager, PushSender pushSender) {
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
}
@Override
public void start() throws Exception {
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
this.messagePersister.start();
}
@Override
public void stop() throws Exception {
messagePersister.shutdown();
logger.info("Message persister shut down...");
}
private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null);
}
private static class Key {
private final byte[] userMessageQueue;
private final byte[] userMessageQueueMetadata;
private final byte[] userMessageQueuePersistInProgress;
private final String address;
private final long deviceId;
Key(String address, long deviceId) {
this.address = address;
this.deviceId = deviceId;
this.userMessageQueue = ("user_queue::" + address + "::" + deviceId).getBytes();
this.userMessageQueueMetadata = ("user_queue_metadata::" + address + "::" + deviceId).getBytes();
this.userMessageQueuePersistInProgress = ("user_queue_persisting::" + address + "::" + deviceId).getBytes();
}
String getAddress() {
return address;
}
long getDeviceId() {
return deviceId;
}
byte[] getUserMessageQueue() {
return userMessageQueue;
}
byte[] getUserMessageQueueMetadata() {
return userMessageQueueMetadata;
}
byte[] getUserMessageQueuePersistInProgress() {
return userMessageQueuePersistInProgress;
}
static byte[] getUserMessageQueueIndex() {
return "user_queue_index".getBytes();
}
static Key fromUserMessageQueue(byte[] userMessageQueue) throws IOException {
try {
String[] parts = new String(userMessageQueue).split("::");
if (parts.length != 3) {
throw new IOException("Malformed key: " + new String(userMessageQueue));
}
return new Key(parts[1], Long.parseLong(parts[2]));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
}
private static class InsertOperation {
private final LuaScript insert;
InsertOperation(JedisPool jedisPool) throws IOException {
this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
}
public void insert(String destination, long destinationDevice, long timestamp, Envelope message) {
Key key = new Key(destination, destinationDevice);
String sender = message.getSource() + "::" + message.getTimestamp();
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes());
insert.execute(keys, args);
}
}
private static class RemoveOperation {
private final LuaScript removeById;
private final LuaScript removeBySender;
private final LuaScript removeQueue;
RemoveOperation(JedisPool jedisPool) throws IOException {
this.removeById = LuaScript.fromResource(jedisPool, "lua/remove_item_by_id.lua" );
this.removeBySender = LuaScript.fromResource(jedisPool, "lua/remove_item_by_sender.lua");
this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
}
public void remove(String destination, long destinationDevice, long id) {
Key key = new Key(destination, destinationDevice);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(String.valueOf(id).getBytes());
this.removeById.execute(keys, args);
}
public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {
Key key = new Key(destination, destinationDevice);
String senderKey = sender + "::" + timestamp;
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(senderKey.getBytes());
return (byte[])this.removeBySender.execute(keys, args);
}
public void clear(String destination, long deviceId) {
Key key = new Key(destination, deviceId);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = new LinkedList<>();
this.removeQueue.execute(keys, args);
}
}
private static class GetOperation {
private final LuaScript getQueues;
private final LuaScript getItems;
GetOperation(JedisPool jedisPool) throws IOException {
this.getQueues = LuaScript.fromResource(jedisPool, "lua/get_queues_to_persist.lua");
this.getItems = LuaScript.fromResource(jedisPool, "lua/get_items.lua");
}
List<byte[]> getQueues(byte[] queue, long maxTimeMillis, int limit) {
List<byte[]> keys = Collections.singletonList(queue);
List<byte[]> args = Arrays.asList(String.valueOf(maxTimeMillis).getBytes(), String.valueOf(limit).getBytes());
return (List<byte[]>)getQueues.execute(keys, args);
}
List<Pair<byte[], Double>> getItems(byte[] queue, byte[] lock, int limit) {
List<byte[]> keys = Arrays.asList(queue, lock);
List<byte[]> args = Collections.singletonList(String.valueOf(limit).getBytes());
Iterator<byte[]> results = ((List<byte[]>) getItems.execute(keys, args)).iterator();
List<Pair<byte[], Double>> items = new LinkedList<>();
while (results.hasNext()) {
items.add(new Pair<>(results.next(), Double.valueOf(SafeEncoder.encode(results.next()))));
}
return items;
}
}
private static class MessagePersister extends Thread {
private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer getQueuesTimer = metricRegistry.timer(name(MessagesCache.class, "getQueues" ));
private static final Timer persistQueueTimer = metricRegistry.timer(name(MessagesCache.class, "persistQueue"));
private static final Timer notifyTimer = metricRegistry.timer(name(MessagesCache.class, "notifyUser" ));
private static final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueSize" ));
private static final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagesCache.class, "persistQueueCount"));
private static final int CHUNK_SIZE = 100;
private final AtomicBoolean running = new AtomicBoolean(true);
private final JedisPool jedisPool;
private final Messages database;
private final long delayTime;
private final TimeUnit delayTimeUnit;
private final PubSubManager pubSubManager;
private final PushSender pushSender;
private final AccountsManager accountsManager;
private final GetOperation getOperation;
private final RemoveOperation removeOperation;
private boolean finished = false;
MessagePersister(JedisPool jedisPool,
Messages database,
PubSubManager pubSubManager,
PushSender pushSender,
AccountsManager accountsManager,
long delayTime,
TimeUnit delayTimeUnit)
throws IOException
{
super(MessagePersister.class.getSimpleName());
this.jedisPool = jedisPool;
this.database = database;
this.pubSubManager = pubSubManager;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.delayTime = delayTime;
this.delayTimeUnit = delayTimeUnit;
this.getOperation = new GetOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
}
@Override
public void run() {
while (running.get()) {
try {
List<byte[]> queuesToPersist = getQueuesToPersist(getOperation);
queueCountHistogram.update(queuesToPersist.size());
for (byte[] queue : queuesToPersist) {
Key key = Key.fromUserMessageQueue(queue);
persistQueue(jedisPool, key);
notifyClients(accountsManager, pubSubManager, pushSender, key);
}
if (queuesToPersist.isEmpty()) {
Thread.sleep(10000);
}
} catch (Throwable t) {
logger.error("Exception while persisting: ", t);
}
}
synchronized (this) {
finished = true;
notifyAll();
}
}
synchronized void shutdown() {
running.set(false);
while (!finished) Util.wait(this);
}
private void persistQueue(JedisPool jedisPool, Key key) throws IOException {
Timer.Context timer = persistQueueTimer.time();
int messagesPersistedCount = 0;
try (Jedis jedis = jedisPool.getResource()) {
while (true) {
jedis.setex(key.getUserMessageQueuePersistInProgress(), 30, "1".getBytes());
Set<Tuple> messages = jedis.zrangeWithScores(key.getUserMessageQueue(), 0, CHUNK_SIZE);
for (Tuple message : messages) {
persistMessage(key, (long)message.getScore(), message.getBinaryElement());
messagesPersistedCount++;
}
if (messages.size() < CHUNK_SIZE) {
jedis.del(key.getUserMessageQueuePersistInProgress());
return;
}
}
} finally {
timer.stop();
queueSizeHistogram.update(messagesPersistedCount);
}
}
private void persistMessage(Key key, long score, byte[] message) {
try {
Envelope envelope = Envelope.parseFrom(message);
database.store(envelope, key.getAddress(), key.getDeviceId());
} catch (InvalidProtocolBufferException e) {
logger.error("Error parsing envelope", e);
}
removeOperation.remove(key.getAddress(), key.getDeviceId(), score);
}
private List<byte[]> getQueuesToPersist(GetOperation getOperation) {
Timer.Context timer = getQueuesTimer.time();
try {
long maxTime = System.currentTimeMillis() - delayTimeUnit.toMillis(delayTime);
return getOperation.getQueues(Key.getUserMessageQueueIndex(), maxTime, 100);
} finally {
timer.stop();
}
}
private void notifyClients(AccountsManager accountsManager, PubSubManager pubSubManager, PushSender pushSender, Key key)
throws IOException
{
Timer.Context timer = notifyTimer.time();
try {
boolean notified = pubSubManager.publish(new WebsocketAddress(key.getAddress(), key.getDeviceId()),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.QUERY_DB)
.build());
if (!notified) {
Optional<Account> account = accountsManager.get(key.getAddress());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(key.getDeviceId());
if (device.isPresent()) {
try {
pushSender.sendQueuedNotification(account.get(), device.get(), false);
} catch (NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!");
} catch (TransientPushFailureException e) {
logger.warn("Transient push failure!", e);
}
}
}
}
} finally {
timer.stop();
}
}
}
}

View File

@ -1,44 +1,118 @@
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Conversions;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import static com.codahale.metrics.MetricRegistry.name;
public class MessagesManager {
private final Messages messages;
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter cacheHitByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitById" ));
private static final Meter cacheMissByIdMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissById" ));
private static final Meter cacheHitByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByName" ));
private static final Meter cacheMissByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissByName"));
public MessagesManager(Messages messages) {
this.messages = messages;
private final Messages messages;
private final MessagesCache messagesCache;
private final Distribution distribution;
public MessagesManager(Messages messages, MessagesCache messagesCache, float cacheRate) {
this.messages = messages;
this.messagesCache = messagesCache;
this.distribution = new Distribution(cacheRate);
}
public int insert(String destination, long destinationDevice, Envelope message) {
return this.messages.store(message, destination, destinationDevice) + 1;
public void insert(String destination, long destinationDevice, Envelope message) {
if (distribution.isQualified(destination, destinationDevice)) {
messagesCache.insert(destination, destinationDevice, message);
} else {
messages.store(message, destination, destinationDevice);
}
}
public OutgoingMessageEntityList getMessagesForDevice(String destination, long destinationDevice) {
List<OutgoingMessageEntity> messages = this.messages.load(destination, destinationDevice);
if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) {
messages.addAll(this.messagesCache.get(destination, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size()));
}
return new OutgoingMessageEntityList(messages, messages.size() >= Messages.RESULT_SET_CHUNK_SIZE);
}
public void clear(String destination) {
this.messagesCache.clear(destination);
this.messages.clear(destination);
}
public void clear(String destination, long deviceId) {
this.messagesCache.clear(destination, deviceId);
this.messages.clear(destination, deviceId);
}
public Optional<OutgoingMessageEntity> delete(String destination, long destinationDevice, String source, long timestamp)
{
return Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp));
Optional<OutgoingMessageEntity> removed = this.messagesCache.remove(destination, destinationDevice, source, timestamp);
if (!removed.isPresent()) {
removed = Optional.fromNullable(this.messages.remove(destination, destinationDevice, source, timestamp));
cacheMissByNameMeter.mark();
} else {
cacheHitByNameMeter.mark();
}
return removed;
}
public void delete(String destination, long id) {
this.messages.remove(destination, id);
public void delete(String destination, long deviceId, long id, boolean cached) {
if (cached) {
this.messagesCache.remove(destination, deviceId, id);
cacheHitByIdMeter.mark();
} else {
this.messages.remove(destination, id);
cacheMissByIdMeter.mark();
}
}
public static class Distribution {
private final float percentage;
public Distribution(float percentage) {
this.percentage = percentage;
}
public boolean isQualified(String address, long device) {
if (percentage <= 0) return false;
if (percentage >= 100) return true;
try {
MessageDigest digest = MessageDigest.getInstance("SHA1");
digest.update(address.getBytes());
digest.update(Conversions.longToByteArray(device));
byte[] result = digest.digest();
int hashCode = Conversions.byteArrayToShort(result);
return hashCode <= 65535 * percentage;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}
}

View File

@ -82,7 +82,7 @@ public class WebSocketConnection implements DispatchChannel {
processStoredMessages();
break;
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.<Long>absent(), false);
sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false);
break;
case PubSubMessage.Type.CONNECTED_VALUE:
if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) {
@ -106,9 +106,9 @@ public class WebSocketConnection implements DispatchChannel {
processStoredMessages();
}
private void sendMessage(final Envelope message,
final Optional<Long> storedMessageId,
final boolean requery)
private void sendMessage(final Envelope message,
final Optional<StoredMessageInfo> storedMessageInfo,
final boolean requery)
{
try {
EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, device.getSignalingKey());
@ -125,17 +125,17 @@ public class WebSocketConnection implements DispatchChannel {
}
if (isSuccessResponse(response)) {
if (storedMessageId.isPresent()) messagesManager.delete(account.getNumber(), storedMessageId.get());
if (!isReceipt) sendDeliveryReceiptFor(message);
if (requery) processStoredMessages();
} else if (!isSuccessResponse(response) && !storedMessageId.isPresent()) {
if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached);
if (!isReceipt) sendDeliveryReceiptFor(message);
if (requery) processStoredMessages();
} else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) {
requeueMessage(message);
}
}
@Override
public void onFailure(@Nonnull Throwable throwable) {
if (!storedMessageId.isPresent()) requeueMessage(message);
if (!storedMessageInfo.isPresent()) requeueMessage(message);
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
@ -148,11 +148,12 @@ public class WebSocketConnection implements DispatchChannel {
}
private void requeueMessage(Envelope message) {
int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message);
boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT;
pushSender.getWebSocketSender().queueMessage(account, device, message);
boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT;
try {
pushSender.sendQueuedNotification(account, device, queueDepth, fallback);
pushSender.sendQueuedNotification(account, device, fallback);
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("requeueMessage", e);
}
@ -162,7 +163,7 @@ public class WebSocketConnection implements DispatchChannel {
try {
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp(),
message.hasRelay() ? Optional.of(message.getRelay()) :
Optional.<String>absent());
Optional.absent());
} catch (NoSuchUserException | NotPushRegisteredException e) {
logger.info("No longer registered " + e.getMessage());
} catch(IOException | TransientPushFailureException e) {
@ -196,11 +197,21 @@ public class WebSocketConnection implements DispatchChannel {
builder.setRelay(message.getRelay());
}
sendMessage(builder.build(), Optional.of(message.getId()), !iterator.hasNext() && messages.hasMore());
sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore());
}
if (!messages.hasMore()) {
client.sendRequest("PUT", "/api/v1/queue/empty", null, Optional.<byte[]>absent());
}
}
private static class StoredMessageInfo {
private final long id;
private final boolean cached;
private StoredMessageInfo(long id, boolean cached) {
this.id = id;
this.cached = cached;
}
}
}

View File

@ -0,0 +1,10 @@
-- keys: queue_key, queue_locked_key
-- argv: limit
local locked = redis.call("GET", KEYS[2])
if locked then
return {}
end
return redis.call("ZRANGE", KEYS[1], 0, ARGV[1], "WITHSCORES")

View File

@ -0,0 +1,10 @@
-- keys: queue_total_index
-- argv: max_time, limit
local results = redis.call("ZRANGEBYSCORE", KEYS[1], 0, ARGV[1], "LIMIT", 0, ARGV[2])
if results and next(results) then
redis.call("ZREM", KEYS[1], unpack(results))
end
return results

View File

@ -0,0 +1,13 @@
-- keys: queue_key, queue_metadata_key, queue_total_index
-- argv: message, current_time, sender_key
local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1])
redis.call("HSET", KEYS[2], ARGV[3], messageId)
redis.call("HSET", KEYS[2], messageId, ARGV[3])
redis.call("EXPIRE", KEYS[1], 7776000)
redis.call("EXPIRE", KEYS[2], 7776000)
redis.call("ZADD", KEYS[3], "NX", ARGV[2], KEYS[1])
return messageId

View File

@ -0,0 +1,16 @@
-- keys: queue_key, queue_metadata_key, queue_index
-- argv: index_to_remove
local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1])
local senderIndex = redis.call("HGET", KEYS[2], ARGV[1])
if senderIndex then
redis.call("HDEL", KEYS[2], senderIndex)
redis.call("HDEL", KEYS[2], ARGV[1])
end
if (redis.call("ZCARD", KEYS[1]) == 0) then
redis.call("ZREM", KEYS[3], KEYS[1])
end
return removedCount > 0

View File

@ -0,0 +1,22 @@
-- keys: queue_key, queue_metadata_key, queue_index
-- argv: sender_to_remove
local messageId = redis.call("HGET", KEYS[2], ARGV[1])
if messageId then
local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], messageId, messageId, "LIMIT", 0, 1)
redis.call("ZREMRANGEBYSCORE", KEYS[1], messageId, messageId)
redis.call("HDEL", KEYS[2], ARGV[1])
redis.call("HDEL", KEYS[2], messageId)
if (redis.call("ZCARD", KEYS[1]) == 0) then
redis.call("ZREM", KEYS[3], KEYS[1])
end
if envelope and next(envelope) then
return envelope[1]
end
end
return nil

View File

@ -0,0 +1,5 @@
-- keys: queue_key, queue_metadata_key, queue_index
redis.call("DEL", KEYS[1])
redis.call("DEL", KEYS[2])
redis.call("ZREM", KEYS[3], KEYS[1])

View File

@ -186,8 +186,8 @@ public class MessageControllerTest {
final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -217,8 +217,8 @@ public class MessageControllerTest {
final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1L, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
add(new OutgoingMessageEntity(2L, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
add(new OutgoingMessageEntity(1L, false, Envelope.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes(), null));
add(new OutgoingMessageEntity(2L, false, Envelope.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null, null));
}};
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
@ -239,13 +239,13 @@ public class MessageControllerTest {
public synchronized void testDeleteMessages() throws Exception {
long timestamp = System.currentTimeMillis();
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31337))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true,
Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp,
"+14152222222", 1, "hi".getBytes(), null)));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, 1, "+14152222222", 31338))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true,
Envelope.Type.RECEIPT_VALUE,
null, System.currentTimeMillis(),
"+14152222222", 1, null, null)));

View File

@ -100,9 +100,9 @@ public class WebSocketConnectionTest {
MessagesManager storedMessages = mock(MessagesManager.class);
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage(1L, "sender1", 1111, false, "first"));
add(createMessage(2L, "sender1", 2222, false, "second"));
add(createMessage(3L, "sender2", 3333, false, "third"));
add(createMessage(1L, false, "sender1", 1111, false, "first"));
add(createMessage(2L, false, "sender1", 2222, false, "second"));
add(createMessage(3L, false, "sender2", 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
@ -157,7 +157,7 @@ public class WebSocketConnectionTest {
futures.get(0).setException(new IOException());
futures.get(2).setException(new IOException());
verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L));
verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L), eq(2L), eq(false));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L), eq(Optional.<String>absent()));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
@ -170,7 +170,6 @@ public class WebSocketConnectionTest {
WebsocketSender websocketSender = mock(WebsocketSender.class);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10);
Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
@ -251,7 +250,7 @@ public class WebSocketConnectionTest {
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.<String>absent()));
verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class));
verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(10), eq(true));
verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(true));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
@ -266,7 +265,6 @@ public class WebSocketConnectionTest {
reset(pushSender);
when(pushSender.getWebSocketSender()).thenReturn(websocketSender);
when(websocketSender.queueMessage(any(Account.class), any(Device.class), any(Envelope.class))).thenReturn(10);
final Envelope firstMessage = Envelope.newBuilder()
.setLegacyMessage(ByteString.copyFrom("first".getBytes()))
@ -285,11 +283,11 @@ public class WebSocketConnectionTest {
.build();
List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1, firstMessage.getType().getNumber(), firstMessage.getRelay(),
add(new OutgoingMessageEntity(1, true, firstMessage.getType().getNumber(), firstMessage.getRelay(),
firstMessage.getTimestamp(), firstMessage.getSource(),
firstMessage.getSourceDevice(), firstMessage.getLegacyMessage().toByteArray(),
firstMessage.getContent().toByteArray()));
add(new OutgoingMessageEntity(2, secondMessage.getType().getNumber(), secondMessage.getRelay(),
add(new OutgoingMessageEntity(2, false, secondMessage.getType().getNumber(), secondMessage.getRelay(),
secondMessage.getTimestamp(), secondMessage.getSource(),
secondMessage.getSourceDevice(), secondMessage.getLegacyMessage().toByteArray(),
secondMessage.getContent().toByteArray()));
@ -355,8 +353,8 @@ public class WebSocketConnectionTest {
}
private OutgoingMessageEntity createMessage(long id, String sender, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, 1, content.getBytes(), null);
}