diff --git a/pom.xml b/pom.xml
index e98586b4c..8fafc0d1b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -125,6 +125,11 @@
postgresql
9.1-901.jdbc4
+
+ org.igniterealtime.smack
+ smack-tcp
+ 4.0.0
+
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
index eb43240e8..b04ad74c4 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java
@@ -47,6 +47,7 @@ public class WhisperServerConfiguration extends Configuration {
private NexmoConfiguration nexmo;
@NotNull
+ @Valid
@JsonProperty
private GcmConfiguration gcm;
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 1f3f35ada..21e8af6b1 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -51,7 +51,10 @@ import org.whispersystems.textsecuregcm.providers.MemcacheHealthCheck;
import org.whispersystems.textsecuregcm.providers.MemcachedClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisHealthCheck;
+import org.whispersystems.textsecuregcm.push.APNSender;
+import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.PushSender;
+import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.sms.NexmoSmsSender;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
@@ -136,6 +139,19 @@ public class WhisperServerService extends Application nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());
UrlSigner urlSigner = new UrlSigner(config.getS3Configuration());
- PushSender pushSender = new PushSender(config.getGcmConfiguration(),
- config.getApnConfiguration(),
- storedMessages, pubSubManager,
- accountsManager);
+ PushSender pushSender = new PushSender(gcmSender, apnSender, websocketSender);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysControllerV1 keysControllerV1 = new KeysControllerV1(rateLimiters, keys, accountsManager, federatedClientManager);
diff --git a/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java
index 7b8300f27..e974e1d42 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java
@@ -19,8 +19,14 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
+import javax.validation.constraints.NotNull;
+
public class GcmConfiguration {
+ @NotNull
+ @JsonProperty
+ private long senderId;
+
@NotEmpty
@JsonProperty
private String apiKey;
@@ -28,4 +34,8 @@ public class GcmConfiguration {
public String getApiKey() {
return apiKey;
}
+
+ public long getSenderId() {
+ return senderId;
+ }
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java
index 4aa3b7904..978cf7666 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java
@@ -10,8 +10,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
-import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingWebsocketMessage;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
@@ -36,9 +36,9 @@ import io.dropwizard.auth.basic.BasicCredentials;
public class WebsocketController implements WebSocketListener, PubSubListener {
- private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
- private static final ObjectMapper mapper = new ObjectMapper();
- private static final Map pendingMessages = new HashMap<>();
+ private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
+ private static final ObjectMapper mapper = new ObjectMapper();
+ private static final Map pendingMessages = new HashMap<>();
private final AccountAuthenticator accountAuthenticator;
private final PubSubManager pubSubManager;
@@ -124,7 +124,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
public void onWebSocketClose(int i, String s) {
pubSubManager.unsubscribe(this.address, this);
- List remainingMessages = new LinkedList<>();
+ List remainingMessages = new LinkedList<>();
synchronized (pendingMessages) {
Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]);
@@ -137,9 +137,9 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
pendingMessages.clear();
}
- for (String remainingMessage : remainingMessages) {
+ for (PendingMessage remainingMessage : remainingMessages) {
try {
- pushSender.sendMessage(account, device, new EncryptedOutgoingMessage(remainingMessage));
+ pushSender.sendMessage(account, device, remainingMessage);
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("onWebSocketClose", e);
storedMessages.insert(account.getId(), device.getId(), remainingMessage);
@@ -147,12 +147,16 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
}
}
-
@Override
public void onPubSubMessage(PubSubMessage outgoingMessage) {
switch (outgoingMessage.getType()) {
case PubSubMessage.TYPE_DELIVER:
- handleDeliverOutgoingMessage(outgoingMessage.getContents());
+ try {
+ PendingMessage pendingMessage = mapper.readValue(outgoingMessage.getContents(), PendingMessage.class);
+ handleDeliverOutgoingMessage(pendingMessage);
+ } catch (IOException e) {
+ logger.warn("WebsocketController", "Error deserializing PendingMessage", e);
+ }
break;
case PubSubMessage.TYPE_QUERY_DB:
handleQueryDatabase();
@@ -162,7 +166,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
}
}
- private void handleDeliverOutgoingMessage(String message) {
+ private void handleDeliverOutgoingMessage(PendingMessage message) {
try {
long messageSequence;
@@ -171,7 +175,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
pendingMessages.put(messageSequence, message);
}
- WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message);
+ WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message.getEncryptedOutgoingMessage());
session.getRemote().sendStringByFuture(mapper.writeValueAsString(websocketMessage));
} catch (IOException e) {
logger.debug("Response failed", e);
@@ -192,9 +196,9 @@ public class WebsocketController implements WebSocketListener, PubSubListener {
}
private void handleQueryDatabase() {
- List messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
+ List messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
- for (String message : messages) {
+ for (PendingMessage message : messages) {
handleDeliverOutgoingMessage(message);
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java
index 89a00695f..efa33c2b2 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java
@@ -55,10 +55,6 @@ public class EncryptedOutgoingMessage {
this.serialized = Base64.encodeBytes(ciphertext);
}
- public EncryptedOutgoingMessage(String serialized) {
- this.serialized = serialized;
- }
-
public String serialize() {
return serialized;
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java
new file mode 100644
index 000000000..abbc0fa4f
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java
@@ -0,0 +1,51 @@
+package org.whispersystems.textsecuregcm.entities;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+public class PendingMessage {
+
+ @JsonProperty
+ private String sender;
+
+ @JsonProperty
+ private long messageId;
+
+ @JsonProperty
+ private String encryptedOutgoingMessage;
+
+ public PendingMessage() {}
+
+ public PendingMessage(String sender, long messageId, String encryptedOutgoingMessage) {
+ this.sender = sender;
+ this.messageId = messageId;
+ this.encryptedOutgoingMessage = encryptedOutgoingMessage;
+ }
+
+ public String getEncryptedOutgoingMessage() {
+ return encryptedOutgoingMessage;
+ }
+
+ public long getMessageId() {
+ return messageId;
+ }
+
+ public String getSender() {
+ return sender;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == null || !(other instanceof PendingMessage)) return false;
+ PendingMessage that = (PendingMessage)other;
+
+ return
+ this.sender.equals(that.sender) &&
+ this.messageId == that.messageId &&
+ this.encryptedOutgoingMessage.equals(that.encryptedOutgoingMessage);
+ }
+
+ @Override
+ public int hashCode() {
+ return this.sender.hashCode() ^ (int)this.messageId ^ this.encryptedOutgoingMessage.hashCode();
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java
index 1bc7c93b7..01f078714 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java
@@ -23,11 +23,14 @@ import com.google.common.base.Optional;
import com.notnoop.apns.APNS;
import com.notnoop.apns.ApnsService;
import com.notnoop.exceptions.NetworkIOException;
+import net.spy.memcached.MemcachedClient;
import org.bouncycastle.openssl.PEMReader;
+import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubMessage;
@@ -47,10 +50,16 @@ import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
+import java.util.Date;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
+import io.dropwizard.lifecycle.Managed;
-public class APNSender {
+public class APNSender implements Managed {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter websocketMeter = metricRegistry.meter(name(getClass(), "websocket"));
@@ -60,39 +69,53 @@ public class APNSender {
private static final String MESSAGE_BODY = "m";
- private final Optional apnService;
- private final PubSubManager pubSubManager;
- private final StoredMessages storedMessages;
+ private static final ObjectMapper mapper = new ObjectMapper();
- public APNSender(PubSubManager pubSubManager,
+ private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
+
+ private final AccountsManager accounts;
+ private final PubSubManager pubSubManager;
+ private final StoredMessages storedMessages;
+ private final MemcachedClient memcachedClient;
+
+ private final String apnCertificate;
+ private final String apnKey;
+
+ private Optional apnService;
+
+ public APNSender(AccountsManager accounts,
+ PubSubManager pubSubManager,
StoredMessages storedMessages,
+ MemcachedClient memcachedClient,
String apnCertificate, String apnKey)
- throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{
- this.pubSubManager = pubSubManager;
- this.storedMessages = storedMessages;
-
- if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) {
- byte[] keyStore = initializeKeyStore(apnCertificate, apnKey);
- this.apnService = Optional.of(APNS.newService()
- .withCert(new ByteArrayInputStream(keyStore), "insecure")
- .withSandboxDestination().build());
- } else {
- this.apnService = Optional.absent();
- }
+ this.accounts = accounts;
+ this.pubSubManager = pubSubManager;
+ this.storedMessages = storedMessages;
+ this.apnCertificate = apnCertificate;
+ this.apnKey = apnKey;
+ this.memcachedClient = memcachedClient;
}
public void sendMessage(Account account, Device device,
- String registrationId, EncryptedOutgoingMessage message)
- throws TransientPushFailureException, NotPushRegisteredException
+ String registrationId, PendingMessage message)
+ throws TransientPushFailureException
{
- if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()),
- new PubSubMessage(PubSubMessage.TYPE_DELIVER, message.serialize())))
- {
- websocketMeter.mark();
- } else {
- storedMessages.insert(account.getId(), device.getId(), message.serialize());
- sendPush(registrationId, message.serialize());
+ try {
+ String serializedPendingMessage = mapper.writeValueAsString(message);
+
+ if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()),
+ new PubSubMessage(PubSubMessage.TYPE_DELIVER,
+ serializedPendingMessage)))
+ {
+ websocketMeter.mark();
+ } else {
+ memcacheSet(registrationId, account.getNumber());
+ storedMessages.insert(account.getId(), device.getId(), message);
+ sendPush(registrationId, serializedPendingMessage);
+ }
+ } catch (IOException e) {
+ throw new TransientPushFailureException(e);
}
}
@@ -129,7 +152,7 @@ public class APNSender {
X509Certificate certificate = (X509Certificate) reader.readObject();
Certificate[] certificateChain = {certificate};
- reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemKey.getBytes())));
+ reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemKey.getBytes())));
KeyPair keyPair = (KeyPair) reader.readObject();
KeyStore keyStore = KeyStore.getInstance("pkcs12");
@@ -143,4 +166,79 @@ public class APNSender {
return baos.toByteArray();
}
+
+ @Override
+ public void start() throws Exception {
+ if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) {
+ byte[] keyStore = initializeKeyStore(apnCertificate, apnKey);
+
+ this.apnService = Optional.of(APNS.newService()
+ .withCert(new ByteArrayInputStream(keyStore), "insecure")
+ .asQueued()
+ .withSandboxDestination().build());
+
+ this.executor.scheduleAtFixedRate(new FeedbackRunnable(), 0, 1, TimeUnit.HOURS);
+ } else {
+ this.apnService = Optional.absent();
+ }
+ }
+
+ @Override
+ public void stop() throws Exception {
+ if (apnService.isPresent()) {
+ apnService.get().stop();
+ }
+ }
+
+ private void memcacheSet(String registrationId, String number) {
+ if (memcachedClient != null) {
+ memcachedClient.set("APN-" + registrationId, 60 * 60 * 24, number);
+ }
+ }
+
+ private Optional memcacheGet(String registrationId) {
+ if (memcachedClient != null) {
+ return Optional.fromNullable((String)memcachedClient.get("APN-" + registrationId));
+ } else {
+ return Optional.absent();
+ }
+ }
+
+ private class FeedbackRunnable implements Runnable {
+ private void updateAccount(Account account, String registrationId) {
+ boolean needsUpdate = false;
+
+ for (Device device : account.getDevices()) {
+ if (registrationId.equals(device.getApnId())) {
+ needsUpdate = true;
+ device.setApnId(null);
+ }
+ }
+
+ if (needsUpdate) {
+ accounts.update(account);
+ }
+ }
+
+ @Override
+ public void run() {
+ if (apnService.isPresent()) {
+ Map inactiveDevices = apnService.get().getInactiveDevices();
+
+ for (String registrationId : inactiveDevices.keySet()) {
+ Optional number = memcacheGet(registrationId);
+
+ if (number.isPresent()) {
+ Optional account = accounts.get(number.get());
+
+ if (account.isPresent()) {
+ updateAccount(account.get(), registrationId);
+ }
+ } else {
+ logger.warn("APN unregister event received for uncached ID: " + registrationId);
+ }
+ }
+ }
+ }
+ }
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java
index 9790bc7a1..c2b0b15c6 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java
@@ -1,69 +1,415 @@
-/**
- * Copyright (C) 2013 Open WhisperSystems
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
-import com.google.android.gcm.server.Constants;
-import com.google.android.gcm.server.Message;
-import com.google.android.gcm.server.Result;
-import com.google.android.gcm.server.Sender;
-import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
+import com.google.common.base.Optional;
+import org.jivesoftware.smack.ConnectionConfiguration;
+import org.jivesoftware.smack.ConnectionListener;
+import org.jivesoftware.smack.PacketListener;
+import org.jivesoftware.smack.SmackException;
+import org.jivesoftware.smack.XMPPConnection;
+import org.jivesoftware.smack.XMPPException;
+import org.jivesoftware.smack.filter.PacketTypeFilter;
+import org.jivesoftware.smack.packet.DefaultPacketExtension;
+import org.jivesoftware.smack.packet.Message;
+import org.jivesoftware.smack.packet.Packet;
+import org.jivesoftware.smack.packet.PacketExtension;
+import org.jivesoftware.smack.provider.PacketExtensionProvider;
+import org.jivesoftware.smack.provider.ProviderManager;
+import org.jivesoftware.smack.tcp.XMPPTCPConnection;
+import org.jivesoftware.smack.util.StringUtils;
+import org.json.simple.JSONObject;
+import org.json.simple.JSONValue;
+import org.json.simple.parser.ParseException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.util.Util;
+import org.xmlpull.v1.XmlPullParser;
+import javax.net.ssl.SSLSocketFactory;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
import static com.codahale.metrics.MetricRegistry.name;
+import io.dropwizard.lifecycle.Managed;
-public class GCMSender {
+public class GCMSender implements Managed, PacketListener {
+
+ private final Logger logger = LoggerFactory.getLogger(GCMSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(org.whispersystems.textsecuregcm.util.Constants.METRICS_NAME);
private final Meter success = metricRegistry.meter(name(getClass(), "sent", "success"));
private final Meter failure = metricRegistry.meter(name(getClass(), "sent", "failure"));
+ private final Meter unregistered = metricRegistry.meter(name(getClass(), "sent", "unregistered"));
- private final Sender sender;
+ private static final String GCM_SERVER = "gcm.googleapis.com";
+ private static final int GCM_PORT = 5235;
- public GCMSender(String apiKey) {
- this.sender = new Sender(apiKey);
+ private static final String GCM_ELEMENT_NAME = "gcm";
+ private static final String GCM_NAMESPACE = "google:mobile:data";
+
+ private final Map pendingMessages = new ConcurrentHashMap<>();
+
+ private final long senderId;
+ private final String apiKey;
+ private final AccountsManager accounts;
+
+ private XMPPTCPConnection connection;
+
+ public GCMSender(AccountsManager accounts, long senderId, String apiKey) {
+ this.accounts = accounts;
+ this.senderId = senderId;
+ this.apiKey = apiKey;
+
+ ProviderManager.addExtensionProvider(GCM_ELEMENT_NAME, GCM_NAMESPACE,
+ new GcmPacketExtensionProvider());
}
- public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage)
- throws NotPushRegisteredException, TransientPushFailureException
+ public void sendMessage(String destinationNumber, long destinationDeviceId,
+ String registrationId, PendingMessage message)
{
+ String messageId = "m-" + UUID.randomUUID().toString();
+ UnacknowledgedMessage unacknowledgedMessage = new UnacknowledgedMessage(destinationNumber,
+ destinationDeviceId,
+ registrationId, message);
+
+ sendMessage(messageId, unacknowledgedMessage);
+ }
+
+ public void sendMessage(String messageId, UnacknowledgedMessage message) {
try {
- Message gcmMessage = new Message.Builder().addData("type", "message")
- .addData("message", outgoingMessage.serialize())
- .build();
+ Map dataObject = new HashMap<>();
+ dataObject.put("type", "message");
+ dataObject.put("message", message.getPendingMessage().getEncryptedOutgoingMessage());
- Result result = sender.send(gcmMessage, gcmRegistrationId, 5);
+ Map messageObject = new HashMap<>();
+ messageObject.put("to", message.getRegistrationId());
+ messageObject.put("message_id", messageId);
+ messageObject.put("data", dataObject);
- if (result.getMessageId() != null) {
- success.mark();
- return result.getCanonicalRegistrationId();
- } else {
- failure.mark();
- if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) {
- throw new NotPushRegisteredException("Device no longer registered with GCM.");
- } else {
- throw new TransientPushFailureException("GCM Failed: " + result.getErrorCodeName());
- }
+ String json = JSONObject.toJSONString(messageObject);
+
+ pendingMessages.put(messageId, message);
+ connection.sendPacket(new GcmPacketExtension(json).toPacket());
+ } catch (SmackException.NotConnectedException e) {
+ logger.warn("GCMClient", "No connection", e);
+ }
+ }
+
+ @Override
+ public void start() throws Exception {
+ this.connection = connect(senderId, apiKey);
+ }
+
+ @Override
+ public void stop() throws Exception {
+ this.connection.disconnect();
+ }
+
+ @Override
+ public void processPacket(Packet packet) throws SmackException.NotConnectedException {
+ Message incomingMessage = (Message) packet;
+ GcmPacketExtension gcmPacket = (GcmPacketExtension) incomingMessage.getExtension(GCM_NAMESPACE);
+ String json = gcmPacket.getJson();
+
+ try {
+ Map jsonObject = (Map) JSONValue.parseWithException(json);
+ Object messageType = jsonObject.get("message_type");
+
+ if (messageType == null) {
+ handleUpstreamMessage(jsonObject);
+ return;
}
- } catch (IOException e) {
- throw new TransientPushFailureException(e);
+
+ switch (messageType.toString()) {
+ case "ack" : handleAckReceipt(jsonObject); break;
+ case "nack" : handleNackReceipt(jsonObject); break;
+ case "receipt" : handleDeliveryReceipt(jsonObject); break;
+ case "control" : handleControlMessage(jsonObject); break;
+ default:
+ logger.warn("Received unknown GCM message: " + messageType.toString());
+ }
+
+ } catch (ParseException e) {
+ logger.warn("GCMClient", "Received unparsable message", e);
+ } catch (Exception e) {
+ logger.warn("GCMClient", "Failed to process packet", e);
+ }
+ }
+
+ private void handleControlMessage(Map message) {
+ String controlType = (String) message.get("control_type");
+
+ if ("CONNECTION_DRAINING".equals(controlType)) {
+ logger.warn("GCM Connection is draining! Initiating reconnect...");
+ reconnect();
+ } else {
+ logger.warn("Received unknown GCM control message: " + controlType);
+ }
+ }
+
+ private void handleDeliveryReceipt(Map message) {
+ logger.warn("Got delivery receipt!");
+ }
+
+ private void handleNackReceipt(Map message) {
+ String messageId = (String) message.get("message_id");
+ String errorCode = (String) message.get("error");
+
+ if (errorCode == null) {
+ logger.warn("Null GCM error code!");
+ if (messageId != null) {
+ pendingMessages.remove(messageId);
+ }
+
+ return;
+ }
+
+ switch (errorCode) {
+ case "BAD_REGISTRATION" : handleBadRegistration(message); break;
+ case "DEVICE_UNREGISTERED" : handleBadRegistration(message); break;
+ case "INTERNAL_SERVER_ERROR" : handleServerFailure(message); break;
+ case "INVALID_JSON" : handleClientFailure(message); break;
+ case "QUOTA_EXCEEDED" : handleClientFailure(message); break;
+ case "SERVICE_UNAVAILABLE" : handleServerFailure(message); break;
+ }
+ }
+
+ private void handleAckReceipt(Map message) {
+ success.mark();
+
+ String messageId = (String) message.get("message_id");
+
+ if (messageId != null) {
+ pendingMessages.remove(messageId);
+ }
+ }
+
+ private void handleUpstreamMessage(Map message)
+ throws SmackException.NotConnectedException
+ {
+ logger.warn("Got upstream message from GCM Server!");
+ Map ack = new HashMap<>();
+ message.put("message_type", "ack");
+ message.put("to", message.get("from"));
+ message.put("message_id", message.get("message_id"));
+
+ String json = JSONValue.toJSONString(ack);
+
+ Packet request = new GcmPacketExtension(json).toPacket();
+ connection.sendPacket(request);
+ }
+
+ private void handleBadRegistration(Map message) {
+ unregistered.mark();
+
+ String messageId = (String) message.get("message_id");
+
+ if (messageId != null) {
+ UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId);
+
+ if (unacknowledgedMessage != null) {
+ Optional account = accounts.get(unacknowledgedMessage.getDestinationNumber());
+
+ if (account.isPresent()) {
+ Optional device = account.get().getDevice(unacknowledgedMessage.getDestinationDeviceId());
+
+ if (device.isPresent()) {
+ device.get().setGcmId(null);
+ accounts.update(account.get());
+ }
+ }
+
+ }
+ }
+ }
+
+ private void handleServerFailure(Map message) {
+ failure.mark();
+
+ String messageId = (String)message.get("message_id");
+
+ if (messageId != null) {
+ UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId);
+
+ if (unacknowledgedMessage != null) {
+ sendMessage(messageId, unacknowledgedMessage);
+ }
+ }
+ }
+
+ private void handleClientFailure(Map message) {
+ failure.mark();
+
+ logger.warn("Unrecoverable error: " + message.get("error"));
+ String messageId = (String)message.get("message_id");
+
+ if (messageId != null) {
+ pendingMessages.remove(messageId);
+ }
+ }
+
+ private void reconnect() {
+ try {
+ this.connection.disconnect();
+ } catch (SmackException.NotConnectedException e) {
+ logger.warn("GCMClient", "Disconnect attempt", e);
+ }
+
+ while (true) {
+ try {
+ this.connection = connect(senderId, apiKey);
+ return;
+ } catch (XMPPException | IOException | SmackException e) {
+ logger.warn("GCMClient", "Reconnecting", e);
+ Util.sleep(1000);
+ }
+ }
+ }
+
+ private XMPPTCPConnection connect(long senderId, String apiKey)
+ throws XMPPException, IOException, SmackException
+ {
+ ConnectionConfiguration config = new ConnectionConfiguration(GCM_SERVER, GCM_PORT);
+ config.setSecurityMode(ConnectionConfiguration.SecurityMode.enabled);
+ config.setReconnectionAllowed(true);
+ config.setRosterLoadedAtLogin(false);
+ config.setSendPresence(false);
+ config.setSocketFactory(SSLSocketFactory.getDefault());
+
+ XMPPTCPConnection connection = new XMPPTCPConnection(config);
+ connection.connect();
+
+ connection.addConnectionListener(new LoggingConnectionListener());
+ connection.addPacketListener(this, new PacketTypeFilter(Message.class));
+
+ connection.login(senderId + "@gcm.googleapis.com", apiKey);
+
+ return connection;
+ }
+
+ private static class GcmPacketExtensionProvider implements PacketExtensionProvider {
+ @Override
+ public PacketExtension parseExtension(XmlPullParser xmlPullParser) throws Exception {
+ String json = xmlPullParser.nextText();
+ return new GcmPacketExtension(json);
+ }
+ }
+
+ private static final class GcmPacketExtension extends DefaultPacketExtension {
+
+ private final String json;
+
+ public GcmPacketExtension(String json) {
+ super(GCM_ELEMENT_NAME, GCM_NAMESPACE);
+ this.json = json;
+ }
+
+ public String getJson() {
+ return json;
+ }
+
+ @Override
+ public String toXML() {
+ return String.format("<%s xmlns=\"%s\">%s%s>", GCM_ELEMENT_NAME, GCM_NAMESPACE,
+ StringUtils.escapeForXML(json), GCM_ELEMENT_NAME);
+ }
+
+ public Packet toPacket() {
+ Message message = new Message();
+ message.addExtension(this);
+ return message;
+ }
+ }
+
+ private class LoggingConnectionListener implements ConnectionListener {
+
+ @Override
+ public void connected(XMPPConnection xmppConnection) {
+ logger.warn("GCM XMPP Connected.");
+ }
+
+ @Override
+ public void authenticated(XMPPConnection xmppConnection) {
+ logger.warn("GCM XMPP Authenticated.");
+ }
+
+ @Override
+ public void reconnectionSuccessful() {
+ logger.warn("GCM XMPP Reconnecting..");
+ Iterator> iterator =
+ pendingMessages.entrySet().iterator();
+
+ while (iterator.hasNext()) {
+ Map.Entry entry = iterator.next();
+ iterator.remove();
+
+ sendMessage(entry.getKey(), entry.getValue());
+ }
+ }
+
+ @Override
+ public void reconnectionFailed(Exception e) {
+ logger.warn("GCM XMPP Reconnection failed!", e);
+ }
+
+ @Override
+ public void reconnectingIn(int seconds) {
+ logger.warn(String.format("GCM XMPP Reconnecting in %d secs", seconds));
+ }
+
+ @Override
+ public void connectionClosedOnError(Exception e) {
+ logger.warn("GCM XMPP Connection closed on error.");
+ }
+
+ @Override
+ public void connectionClosed() {
+ logger.warn("GCM XMPP Connection closed.");
+ }
+ }
+
+ private static class UnacknowledgedMessage {
+ private final String destinationNumber;
+ private final long destinationDeviceId;
+
+ private final String registrationId;
+ private final PendingMessage pendingMessage;
+
+ private UnacknowledgedMessage(String destinationNumber,
+ long destinationDeviceId,
+ String registrationId,
+ PendingMessage pendingMessage)
+ {
+ this.destinationNumber = destinationNumber;
+ this.destinationDeviceId = destinationDeviceId;
+ this.registrationId = registrationId;
+ this.pendingMessage = pendingMessage;
+ }
+
+ private String getRegistrationId() {
+ return registrationId;
+ }
+
+ private PendingMessage getPendingMessage() {
+ return pendingMessage;
+ }
+
+ public String getDestinationNumber() {
+ return destinationNumber;
+ }
+
+ public long getDestinationDeviceId() {
+ return destinationDeviceId;
}
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
index c9df1f9c2..fc458f3f9 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
@@ -18,44 +18,28 @@ package org.whispersystems.textsecuregcm.push;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
-import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account;
-import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
-import org.whispersystems.textsecuregcm.storage.PubSubManager;
-import org.whispersystems.textsecuregcm.storage.StoredMessages;
-
-import java.io.IOException;
-import java.security.KeyStoreException;
-import java.security.NoSuchAlgorithmException;
-import java.security.cert.CertificateException;
public class PushSender {
private final Logger logger = LoggerFactory.getLogger(PushSender.class);
- private final AccountsManager accounts;
private final GCMSender gcmSender;
private final APNSender apnSender;
private final WebsocketSender webSocketSender;
- public PushSender(GcmConfiguration gcmConfiguration,
- ApnConfiguration apnConfiguration,
- StoredMessages storedMessages,
- PubSubManager pubSubManager,
- AccountsManager accounts)
- throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
+ public PushSender(GCMSender gcmClient,
+ APNSender apnSender,
+ WebsocketSender websocketSender)
{
- this.accounts = accounts;
- this.webSocketSender = new WebsocketSender(storedMessages, pubSubManager);
- this.gcmSender = new GCMSender(gcmConfiguration.getApiKey());
- this.apnSender = new APNSender(pubSubManager, storedMessages,
- apnConfiguration.getCertificate(),
- apnConfiguration.getKey());
+ this.gcmSender = gcmClient;
+ this.apnSender = apnSender;
+ this.webSocketSender = websocketSender;
}
public void sendMessage(Account account, Device device, MessageProtos.OutgoingMessageSignal message)
@@ -64,60 +48,39 @@ public class PushSender {
try {
String signalingKey = device.getSignalingKey();
EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, signalingKey);
+ PendingMessage pendingMessage = new PendingMessage(message.getSource(), message.getTimestamp(), encryptedMessage.serialize());
- sendMessage(account, device, encryptedMessage);
+ sendMessage(account, device, pendingMessage);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
}
- public void sendMessage(Account account, Device device, EncryptedOutgoingMessage message)
+ public void sendMessage(Account account, Device device, PendingMessage pendingMessage)
throws NotPushRegisteredException, TransientPushFailureException
{
- if (device.getGcmId() != null) sendGcmMessage(account, device, message);
- else if (device.getApnId() != null) sendApnMessage(account, device, message);
- else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message);
+ if (device.getGcmId() != null) sendGcmMessage(account, device, pendingMessage);
+ else if (device.getApnId() != null) sendApnMessage(account, device, pendingMessage);
+ else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, pendingMessage);
else throw new NotPushRegisteredException("No delivery possible!");
}
- private void sendGcmMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
- throws NotPushRegisteredException, TransientPushFailureException
- {
- try {
- String canonicalId = gcmSender.sendMessage(device.getGcmId(), outgoingMessage);
+ private void sendGcmMessage(Account account, Device device, PendingMessage pendingMessage) {
+ String number = account.getNumber();
+ long deviceId = device.getId();
+ String registrationId = device.getGcmId();
- if (canonicalId != null) {
- device.setGcmId(canonicalId);
- accounts.update(account);
- }
-
- } catch (NotPushRegisteredException e) {
- logger.debug("No Such User", e);
- device.setGcmId(null);
- accounts.update(account);
- throw new NotPushRegisteredException(e);
- }
+ gcmSender.sendMessage(number, deviceId, registrationId, pendingMessage);
}
- private void sendApnMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
- throws TransientPushFailureException, NotPushRegisteredException
+ private void sendApnMessage(Account account, Device device, PendingMessage outgoingMessage)
+ throws TransientPushFailureException
{
- try {
- apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage);
- } catch (NotPushRegisteredException e) {
- device.setApnId(null);
- accounts.update(account);
- throw new NotPushRegisteredException(e);
- }
+ apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage);
}
- private void sendWebSocketMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
- throws NotPushRegisteredException
+ private void sendWebSocketMessage(Account account, Device device, PendingMessage outgoingMessage)
{
- try {
- webSocketSender.sendMessage(account, device, outgoingMessage);
- } catch (CryptoEncodingException e) {
- throw new NotPushRegisteredException(e);
- }
+ webSocketSender.sendMessage(account, device, outgoingMessage);
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
index c549da4b3..612772e94 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
@@ -19,8 +19,12 @@ package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
-import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
-import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.controllers.WebsocketController;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
@@ -29,16 +33,18 @@ import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
-import java.util.List;
-
import static com.codahale.metrics.MetricRegistry.name;
public class WebsocketSender {
+ private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
+
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online"));
private final Meter offlineMeter = metricRegistry.meter(name(getClass(), "offline"));
+ private static final ObjectMapper mapper = new ObjectMapper();
+
private final StoredMessages storedMessages;
private final PubSubManager pubSubManager;
@@ -47,22 +53,21 @@ public class WebsocketSender {
this.pubSubManager = pubSubManager;
}
- public void sendMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
- throws CryptoEncodingException
- {
- sendMessage(account, device, outgoingMessage.serialize());
- }
+ public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
+ try {
+ String serialized = mapper.writeValueAsString(pendingMessage);
+ WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
+ PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
- private void sendMessage(Account account, Device device, String serializedMessage) {
- WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
- PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage);
-
- if (pubSubManager.publish(address, pubSubMessage)) {
- onlineMeter.mark();
- } else {
- offlineMeter.mark();
- storedMessages.insert(account.getId(), device.getId(), serializedMessage);
- pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
+ if (pubSubManager.publish(address, pubSubMessage)) {
+ onlineMeter.mark();
+ } else {
+ offlineMeter.mark();
+ storedMessages.insert(account.getId(), device.getId(), pendingMessage);
+ pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
+ }
+ } catch (JsonProcessingException e) {
+ logger.warn("WebsocketSender", "Unable to serialize json", e);
}
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
index b954e4cda..3a52da634 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
@@ -19,9 +19,15 @@ package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.util.Constants;
+import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
@@ -31,9 +37,13 @@ import redis.clients.jedis.JedisPool;
public class StoredMessages {
+ private static final Logger logger = LoggerFactory.getLogger(StoredMessages.class);
+
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(getClass(), "queue_size"));
+
+ private static final ObjectMapper mapper = new ObjectMapper();
private static final String QUEUE_PREFIX = "msgs";
private final JedisPool jedisPool;
@@ -42,34 +52,42 @@ public class StoredMessages {
this.jedisPool = jedisPool;
}
- public void insert(long accountId, long deviceId, String message) {
+ public void insert(long accountId, long deviceId, PendingMessage message) {
Jedis jedis = null;
try {
jedis = jedisPool.getResource();
- long queueSize = jedis.lpush(getKey(accountId, deviceId), message);
+ String serializedMessage = mapper.writeValueAsString(message);
+ long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage);
queueSizeHistogram.update(queueSize);
if (queueSize > 1000) {
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
}
+
+ } catch (JsonProcessingException e) {
+ logger.warn("StoredMessages", "Unable to store correctly", e);
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
- public List getMessagesForDevice(long accountId, long deviceId) {
- List messages = new LinkedList<>();
- Jedis jedis = null;
+ public List getMessagesForDevice(long accountId, long deviceId) {
+ List messages = new LinkedList<>();
+ Jedis jedis = null;
try {
jedis = jedisPool.getResource();
String message;
while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) {
- messages.add(message);
+ try {
+ messages.add(mapper.readValue(message, PendingMessage.class));
+ } catch (IOException e) {
+ logger.warn("StoredMessages", "Not a valid PendingMessage", e);
+ }
}
return messages;
diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java
index 8b9b960f6..b1e5c256b 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java
@@ -83,4 +83,40 @@ public class Util {
return result;
}
+
+ public static byte[][] split(byte[] input, int firstLength, int secondLength) {
+ byte[][] parts = new byte[2][];
+
+ parts[0] = new byte[firstLength];
+ System.arraycopy(input, 0, parts[0], 0, firstLength);
+
+ parts[1] = new byte[secondLength];
+ System.arraycopy(input, firstLength, parts[1], 0, secondLength);
+
+ return parts;
+ }
+
+ public static byte[][] split(byte[] input, int firstLength, int secondLength, int thirdLength, int fourthLength) {
+ byte[][] parts = new byte[4][];
+
+ parts[0] = new byte[firstLength];
+ System.arraycopy(input, 0, parts[0], 0, firstLength);
+
+ parts[1] = new byte[secondLength];
+ System.arraycopy(input, firstLength, parts[1], 0, secondLength);
+
+ parts[2] = new byte[thirdLength];
+ System.arraycopy(input, firstLength + secondLength, parts[2], 0, thirdLength);
+
+ parts[3] = new byte[fourthLength];
+ System.arraycopy(input, firstLength + secondLength + thirdLength, parts[3], 0, fourthLength);
+
+ return parts;
+ }
+
+ public static void sleep(int i) {
+ try {
+ Thread.sleep(i);
+ } catch (InterruptedException ie) {}
+ }
}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java
index 296ef0331..f1006c2fb 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java
@@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.controllers.WebsocketController;
import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -83,10 +84,10 @@ public class WebsocketControllerTest {
public void testOpen() throws Exception {
RemoteEndpoint remote = mock(RemoteEndpoint.class);
- List outgoingMessages = new LinkedList() {{
- add("first");
- add("second");
- add("third");
+ List outgoingMessages = new LinkedList() {{
+ add(new PendingMessage("sender1", 1111, "first"));
+ add(new PendingMessage("sender1", 2222, "second"));
+ add(new PendingMessage("sender2", 3333, "third"));
}};
when(device.getId()).thenReturn(2L);
@@ -103,7 +104,8 @@ public class WebsocketControllerTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
- when(storedMessages.getMessagesForDevice(account.getId(), device.getId())).thenReturn(outgoingMessages);
+ when(storedMessages.getMessagesForDevice(account.getId(), device.getId()))
+ .thenReturn(outgoingMessages);
WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, pushSender, storedMessages, pubSubManager);
WebsocketController controller = (WebsocketController) factory.createWebSocket(null, null);
@@ -116,12 +118,13 @@ public class WebsocketControllerTest {
controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));
controller.onWebSocketClose(1000, "Closed");
- List pending = new LinkedList() {{
- add("first");
- add("third");
+ List pending = new LinkedList() {{
+ add(new PendingMessage("sender1", 1111, "first"));
+ add(new PendingMessage("sender2", 3333, "third"));
}};
- verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(EncryptedOutgoingMessage.class));
+
+ verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(PendingMessage.class));
}
}