diff --git a/pom.xml b/pom.xml index e98586b4c..8fafc0d1b 100644 --- a/pom.xml +++ b/pom.xml @@ -125,6 +125,11 @@ postgresql 9.1-901.jdbc4 + + org.igniterealtime.smack + smack-tcp + 4.0.0 + diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index eb43240e8..b04ad74c4 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -47,6 +47,7 @@ public class WhisperServerConfiguration extends Configuration { private NexmoConfiguration nexmo; @NotNull + @Valid @JsonProperty private GcmConfiguration gcm; diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 1f3f35ada..21e8af6b1 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -51,7 +51,10 @@ import org.whispersystems.textsecuregcm.providers.MemcacheHealthCheck; import org.whispersystems.textsecuregcm.providers.MemcachedClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisHealthCheck; +import org.whispersystems.textsecuregcm.push.APNSender; +import org.whispersystems.textsecuregcm.push.GCMSender; import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.push.WebsocketSender; import org.whispersystems.textsecuregcm.sms.NexmoSmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; @@ -136,6 +139,19 @@ public class WhisperServerService extends Application nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration()); SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational()); UrlSigner urlSigner = new UrlSigner(config.getS3Configuration()); - PushSender pushSender = new PushSender(config.getGcmConfiguration(), - config.getApnConfiguration(), - storedMessages, pubSubManager, - accountsManager); + PushSender pushSender = new PushSender(gcmSender, apnSender, websocketSender); AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner); KeysControllerV1 keysControllerV1 = new KeysControllerV1(rateLimiters, keys, accountsManager, federatedClientManager); diff --git a/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java index 7b8300f27..e974e1d42 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java +++ b/src/main/java/org/whispersystems/textsecuregcm/configuration/GcmConfiguration.java @@ -19,8 +19,14 @@ package org.whispersystems.textsecuregcm.configuration; import com.fasterxml.jackson.annotation.JsonProperty; import org.hibernate.validator.constraints.NotEmpty; +import javax.validation.constraints.NotNull; + public class GcmConfiguration { + @NotNull + @JsonProperty + private long senderId; + @NotEmpty @JsonProperty private String apiKey; @@ -28,4 +34,8 @@ public class GcmConfiguration { public String getApiKey() { return apiKey; } + + public long getSenderId() { + return senderId; + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java index 4aa3b7904..978cf7666 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java @@ -10,8 +10,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage; -import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.IncomingWebsocketMessage; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.TransientPushFailureException; @@ -36,9 +36,9 @@ import io.dropwizard.auth.basic.BasicCredentials; public class WebsocketController implements WebSocketListener, PubSubListener { - private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class); - private static final ObjectMapper mapper = new ObjectMapper(); - private static final Map pendingMessages = new HashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class); + private static final ObjectMapper mapper = new ObjectMapper(); + private static final Map pendingMessages = new HashMap<>(); private final AccountAuthenticator accountAuthenticator; private final PubSubManager pubSubManager; @@ -124,7 +124,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener { public void onWebSocketClose(int i, String s) { pubSubManager.unsubscribe(this.address, this); - List remainingMessages = new LinkedList<>(); + List remainingMessages = new LinkedList<>(); synchronized (pendingMessages) { Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]); @@ -137,9 +137,9 @@ public class WebsocketController implements WebSocketListener, PubSubListener { pendingMessages.clear(); } - for (String remainingMessage : remainingMessages) { + for (PendingMessage remainingMessage : remainingMessages) { try { - pushSender.sendMessage(account, device, new EncryptedOutgoingMessage(remainingMessage)); + pushSender.sendMessage(account, device, remainingMessage); } catch (NotPushRegisteredException | TransientPushFailureException e) { logger.warn("onWebSocketClose", e); storedMessages.insert(account.getId(), device.getId(), remainingMessage); @@ -147,12 +147,16 @@ public class WebsocketController implements WebSocketListener, PubSubListener { } } - @Override public void onPubSubMessage(PubSubMessage outgoingMessage) { switch (outgoingMessage.getType()) { case PubSubMessage.TYPE_DELIVER: - handleDeliverOutgoingMessage(outgoingMessage.getContents()); + try { + PendingMessage pendingMessage = mapper.readValue(outgoingMessage.getContents(), PendingMessage.class); + handleDeliverOutgoingMessage(pendingMessage); + } catch (IOException e) { + logger.warn("WebsocketController", "Error deserializing PendingMessage", e); + } break; case PubSubMessage.TYPE_QUERY_DB: handleQueryDatabase(); @@ -162,7 +166,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener { } } - private void handleDeliverOutgoingMessage(String message) { + private void handleDeliverOutgoingMessage(PendingMessage message) { try { long messageSequence; @@ -171,7 +175,7 @@ public class WebsocketController implements WebSocketListener, PubSubListener { pendingMessages.put(messageSequence, message); } - WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message); + WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message.getEncryptedOutgoingMessage()); session.getRemote().sendStringByFuture(mapper.writeValueAsString(websocketMessage)); } catch (IOException e) { logger.debug("Response failed", e); @@ -192,9 +196,9 @@ public class WebsocketController implements WebSocketListener, PubSubListener { } private void handleQueryDatabase() { - List messages = storedMessages.getMessagesForDevice(account.getId(), device.getId()); + List messages = storedMessages.getMessagesForDevice(account.getId(), device.getId()); - for (String message : messages) { + for (PendingMessage message : messages) { handleDeliverOutgoingMessage(message); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java index 89a00695f..efa33c2b2 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedOutgoingMessage.java @@ -55,10 +55,6 @@ public class EncryptedOutgoingMessage { this.serialized = Base64.encodeBytes(ciphertext); } - public EncryptedOutgoingMessage(String serialized) { - this.serialized = serialized; - } - public String serialize() { return serialized; } diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java new file mode 100644 index 000000000..abbc0fa4f --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/PendingMessage.java @@ -0,0 +1,51 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class PendingMessage { + + @JsonProperty + private String sender; + + @JsonProperty + private long messageId; + + @JsonProperty + private String encryptedOutgoingMessage; + + public PendingMessage() {} + + public PendingMessage(String sender, long messageId, String encryptedOutgoingMessage) { + this.sender = sender; + this.messageId = messageId; + this.encryptedOutgoingMessage = encryptedOutgoingMessage; + } + + public String getEncryptedOutgoingMessage() { + return encryptedOutgoingMessage; + } + + public long getMessageId() { + return messageId; + } + + public String getSender() { + return sender; + } + + @Override + public boolean equals(Object other) { + if (other == null || !(other instanceof PendingMessage)) return false; + PendingMessage that = (PendingMessage)other; + + return + this.sender.equals(that.sender) && + this.messageId == that.messageId && + this.encryptedOutgoingMessage.equals(that.encryptedOutgoingMessage); + } + + @Override + public int hashCode() { + return this.sender.hashCode() ^ (int)this.messageId ^ this.encryptedOutgoingMessage.hashCode(); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java index 1bc7c93b7..01f078714 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java @@ -23,11 +23,14 @@ import com.google.common.base.Optional; import com.notnoop.apns.APNS; import com.notnoop.apns.ApnsService; import com.notnoop.exceptions.NetworkIOException; +import net.spy.memcached.MemcachedClient; import org.bouncycastle.openssl.PEMReader; +import org.codehaus.jackson.map.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubMessage; @@ -47,10 +50,16 @@ import java.security.NoSuchAlgorithmException; import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.Date; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import static com.codahale.metrics.MetricRegistry.name; +import io.dropwizard.lifecycle.Managed; -public class APNSender { +public class APNSender implements Managed { private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Meter websocketMeter = metricRegistry.meter(name(getClass(), "websocket")); @@ -60,39 +69,53 @@ public class APNSender { private static final String MESSAGE_BODY = "m"; - private final Optional apnService; - private final PubSubManager pubSubManager; - private final StoredMessages storedMessages; + private static final ObjectMapper mapper = new ObjectMapper(); - public APNSender(PubSubManager pubSubManager, + private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + + private final AccountsManager accounts; + private final PubSubManager pubSubManager; + private final StoredMessages storedMessages; + private final MemcachedClient memcachedClient; + + private final String apnCertificate; + private final String apnKey; + + private Optional apnService; + + public APNSender(AccountsManager accounts, + PubSubManager pubSubManager, StoredMessages storedMessages, + MemcachedClient memcachedClient, String apnCertificate, String apnKey) - throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException { - this.pubSubManager = pubSubManager; - this.storedMessages = storedMessages; - - if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) { - byte[] keyStore = initializeKeyStore(apnCertificate, apnKey); - this.apnService = Optional.of(APNS.newService() - .withCert(new ByteArrayInputStream(keyStore), "insecure") - .withSandboxDestination().build()); - } else { - this.apnService = Optional.absent(); - } + this.accounts = accounts; + this.pubSubManager = pubSubManager; + this.storedMessages = storedMessages; + this.apnCertificate = apnCertificate; + this.apnKey = apnKey; + this.memcachedClient = memcachedClient; } public void sendMessage(Account account, Device device, - String registrationId, EncryptedOutgoingMessage message) - throws TransientPushFailureException, NotPushRegisteredException + String registrationId, PendingMessage message) + throws TransientPushFailureException { - if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()), - new PubSubMessage(PubSubMessage.TYPE_DELIVER, message.serialize()))) - { - websocketMeter.mark(); - } else { - storedMessages.insert(account.getId(), device.getId(), message.serialize()); - sendPush(registrationId, message.serialize()); + try { + String serializedPendingMessage = mapper.writeValueAsString(message); + + if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()), + new PubSubMessage(PubSubMessage.TYPE_DELIVER, + serializedPendingMessage))) + { + websocketMeter.mark(); + } else { + memcacheSet(registrationId, account.getNumber()); + storedMessages.insert(account.getId(), device.getId(), message); + sendPush(registrationId, serializedPendingMessage); + } + } catch (IOException e) { + throw new TransientPushFailureException(e); } } @@ -129,7 +152,7 @@ public class APNSender { X509Certificate certificate = (X509Certificate) reader.readObject(); Certificate[] certificateChain = {certificate}; - reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemKey.getBytes()))); + reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemKey.getBytes()))); KeyPair keyPair = (KeyPair) reader.readObject(); KeyStore keyStore = KeyStore.getInstance("pkcs12"); @@ -143,4 +166,79 @@ public class APNSender { return baos.toByteArray(); } + + @Override + public void start() throws Exception { + if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) { + byte[] keyStore = initializeKeyStore(apnCertificate, apnKey); + + this.apnService = Optional.of(APNS.newService() + .withCert(new ByteArrayInputStream(keyStore), "insecure") + .asQueued() + .withSandboxDestination().build()); + + this.executor.scheduleAtFixedRate(new FeedbackRunnable(), 0, 1, TimeUnit.HOURS); + } else { + this.apnService = Optional.absent(); + } + } + + @Override + public void stop() throws Exception { + if (apnService.isPresent()) { + apnService.get().stop(); + } + } + + private void memcacheSet(String registrationId, String number) { + if (memcachedClient != null) { + memcachedClient.set("APN-" + registrationId, 60 * 60 * 24, number); + } + } + + private Optional memcacheGet(String registrationId) { + if (memcachedClient != null) { + return Optional.fromNullable((String)memcachedClient.get("APN-" + registrationId)); + } else { + return Optional.absent(); + } + } + + private class FeedbackRunnable implements Runnable { + private void updateAccount(Account account, String registrationId) { + boolean needsUpdate = false; + + for (Device device : account.getDevices()) { + if (registrationId.equals(device.getApnId())) { + needsUpdate = true; + device.setApnId(null); + } + } + + if (needsUpdate) { + accounts.update(account); + } + } + + @Override + public void run() { + if (apnService.isPresent()) { + Map inactiveDevices = apnService.get().getInactiveDevices(); + + for (String registrationId : inactiveDevices.keySet()) { + Optional number = memcacheGet(registrationId); + + if (number.isPresent()) { + Optional account = accounts.get(number.get()); + + if (account.isPresent()) { + updateAccount(account.get(), registrationId); + } + } else { + logger.warn("APN unregister event received for uncached ID: " + registrationId); + } + } + } + } + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java index 9790bc7a1..c2b0b15c6 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java @@ -1,69 +1,415 @@ -/** - * Copyright (C) 2013 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ package org.whispersystems.textsecuregcm.push; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; -import com.google.android.gcm.server.Constants; -import com.google.android.gcm.server.Message; -import com.google.android.gcm.server.Result; -import com.google.android.gcm.server.Sender; -import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; +import com.google.common.base.Optional; +import org.jivesoftware.smack.ConnectionConfiguration; +import org.jivesoftware.smack.ConnectionListener; +import org.jivesoftware.smack.PacketListener; +import org.jivesoftware.smack.SmackException; +import org.jivesoftware.smack.XMPPConnection; +import org.jivesoftware.smack.XMPPException; +import org.jivesoftware.smack.filter.PacketTypeFilter; +import org.jivesoftware.smack.packet.DefaultPacketExtension; +import org.jivesoftware.smack.packet.Message; +import org.jivesoftware.smack.packet.Packet; +import org.jivesoftware.smack.packet.PacketExtension; +import org.jivesoftware.smack.provider.PacketExtensionProvider; +import org.jivesoftware.smack.provider.ProviderManager; +import org.jivesoftware.smack.tcp.XMPPTCPConnection; +import org.jivesoftware.smack.util.StringUtils; +import org.json.simple.JSONObject; +import org.json.simple.JSONValue; +import org.json.simple.parser.ParseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.PendingMessage; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Util; +import org.xmlpull.v1.XmlPullParser; +import javax.net.ssl.SSLSocketFactory; import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import static com.codahale.metrics.MetricRegistry.name; +import io.dropwizard.lifecycle.Managed; -public class GCMSender { +public class GCMSender implements Managed, PacketListener { + + private final Logger logger = LoggerFactory.getLogger(GCMSender.class); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(org.whispersystems.textsecuregcm.util.Constants.METRICS_NAME); private final Meter success = metricRegistry.meter(name(getClass(), "sent", "success")); private final Meter failure = metricRegistry.meter(name(getClass(), "sent", "failure")); + private final Meter unregistered = metricRegistry.meter(name(getClass(), "sent", "unregistered")); - private final Sender sender; + private static final String GCM_SERVER = "gcm.googleapis.com"; + private static final int GCM_PORT = 5235; - public GCMSender(String apiKey) { - this.sender = new Sender(apiKey); + private static final String GCM_ELEMENT_NAME = "gcm"; + private static final String GCM_NAMESPACE = "google:mobile:data"; + + private final Map pendingMessages = new ConcurrentHashMap<>(); + + private final long senderId; + private final String apiKey; + private final AccountsManager accounts; + + private XMPPTCPConnection connection; + + public GCMSender(AccountsManager accounts, long senderId, String apiKey) { + this.accounts = accounts; + this.senderId = senderId; + this.apiKey = apiKey; + + ProviderManager.addExtensionProvider(GCM_ELEMENT_NAME, GCM_NAMESPACE, + new GcmPacketExtensionProvider()); } - public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage) - throws NotPushRegisteredException, TransientPushFailureException + public void sendMessage(String destinationNumber, long destinationDeviceId, + String registrationId, PendingMessage message) { + String messageId = "m-" + UUID.randomUUID().toString(); + UnacknowledgedMessage unacknowledgedMessage = new UnacknowledgedMessage(destinationNumber, + destinationDeviceId, + registrationId, message); + + sendMessage(messageId, unacknowledgedMessage); + } + + public void sendMessage(String messageId, UnacknowledgedMessage message) { try { - Message gcmMessage = new Message.Builder().addData("type", "message") - .addData("message", outgoingMessage.serialize()) - .build(); + Map dataObject = new HashMap<>(); + dataObject.put("type", "message"); + dataObject.put("message", message.getPendingMessage().getEncryptedOutgoingMessage()); - Result result = sender.send(gcmMessage, gcmRegistrationId, 5); + Map messageObject = new HashMap<>(); + messageObject.put("to", message.getRegistrationId()); + messageObject.put("message_id", messageId); + messageObject.put("data", dataObject); - if (result.getMessageId() != null) { - success.mark(); - return result.getCanonicalRegistrationId(); - } else { - failure.mark(); - if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) { - throw new NotPushRegisteredException("Device no longer registered with GCM."); - } else { - throw new TransientPushFailureException("GCM Failed: " + result.getErrorCodeName()); - } + String json = JSONObject.toJSONString(messageObject); + + pendingMessages.put(messageId, message); + connection.sendPacket(new GcmPacketExtension(json).toPacket()); + } catch (SmackException.NotConnectedException e) { + logger.warn("GCMClient", "No connection", e); + } + } + + @Override + public void start() throws Exception { + this.connection = connect(senderId, apiKey); + } + + @Override + public void stop() throws Exception { + this.connection.disconnect(); + } + + @Override + public void processPacket(Packet packet) throws SmackException.NotConnectedException { + Message incomingMessage = (Message) packet; + GcmPacketExtension gcmPacket = (GcmPacketExtension) incomingMessage.getExtension(GCM_NAMESPACE); + String json = gcmPacket.getJson(); + + try { + Map jsonObject = (Map) JSONValue.parseWithException(json); + Object messageType = jsonObject.get("message_type"); + + if (messageType == null) { + handleUpstreamMessage(jsonObject); + return; } - } catch (IOException e) { - throw new TransientPushFailureException(e); + + switch (messageType.toString()) { + case "ack" : handleAckReceipt(jsonObject); break; + case "nack" : handleNackReceipt(jsonObject); break; + case "receipt" : handleDeliveryReceipt(jsonObject); break; + case "control" : handleControlMessage(jsonObject); break; + default: + logger.warn("Received unknown GCM message: " + messageType.toString()); + } + + } catch (ParseException e) { + logger.warn("GCMClient", "Received unparsable message", e); + } catch (Exception e) { + logger.warn("GCMClient", "Failed to process packet", e); + } + } + + private void handleControlMessage(Map message) { + String controlType = (String) message.get("control_type"); + + if ("CONNECTION_DRAINING".equals(controlType)) { + logger.warn("GCM Connection is draining! Initiating reconnect..."); + reconnect(); + } else { + logger.warn("Received unknown GCM control message: " + controlType); + } + } + + private void handleDeliveryReceipt(Map message) { + logger.warn("Got delivery receipt!"); + } + + private void handleNackReceipt(Map message) { + String messageId = (String) message.get("message_id"); + String errorCode = (String) message.get("error"); + + if (errorCode == null) { + logger.warn("Null GCM error code!"); + if (messageId != null) { + pendingMessages.remove(messageId); + } + + return; + } + + switch (errorCode) { + case "BAD_REGISTRATION" : handleBadRegistration(message); break; + case "DEVICE_UNREGISTERED" : handleBadRegistration(message); break; + case "INTERNAL_SERVER_ERROR" : handleServerFailure(message); break; + case "INVALID_JSON" : handleClientFailure(message); break; + case "QUOTA_EXCEEDED" : handleClientFailure(message); break; + case "SERVICE_UNAVAILABLE" : handleServerFailure(message); break; + } + } + + private void handleAckReceipt(Map message) { + success.mark(); + + String messageId = (String) message.get("message_id"); + + if (messageId != null) { + pendingMessages.remove(messageId); + } + } + + private void handleUpstreamMessage(Map message) + throws SmackException.NotConnectedException + { + logger.warn("Got upstream message from GCM Server!"); + Map ack = new HashMap<>(); + message.put("message_type", "ack"); + message.put("to", message.get("from")); + message.put("message_id", message.get("message_id")); + + String json = JSONValue.toJSONString(ack); + + Packet request = new GcmPacketExtension(json).toPacket(); + connection.sendPacket(request); + } + + private void handleBadRegistration(Map message) { + unregistered.mark(); + + String messageId = (String) message.get("message_id"); + + if (messageId != null) { + UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId); + + if (unacknowledgedMessage != null) { + Optional account = accounts.get(unacknowledgedMessage.getDestinationNumber()); + + if (account.isPresent()) { + Optional device = account.get().getDevice(unacknowledgedMessage.getDestinationDeviceId()); + + if (device.isPresent()) { + device.get().setGcmId(null); + accounts.update(account.get()); + } + } + + } + } + } + + private void handleServerFailure(Map message) { + failure.mark(); + + String messageId = (String)message.get("message_id"); + + if (messageId != null) { + UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId); + + if (unacknowledgedMessage != null) { + sendMessage(messageId, unacknowledgedMessage); + } + } + } + + private void handleClientFailure(Map message) { + failure.mark(); + + logger.warn("Unrecoverable error: " + message.get("error")); + String messageId = (String)message.get("message_id"); + + if (messageId != null) { + pendingMessages.remove(messageId); + } + } + + private void reconnect() { + try { + this.connection.disconnect(); + } catch (SmackException.NotConnectedException e) { + logger.warn("GCMClient", "Disconnect attempt", e); + } + + while (true) { + try { + this.connection = connect(senderId, apiKey); + return; + } catch (XMPPException | IOException | SmackException e) { + logger.warn("GCMClient", "Reconnecting", e); + Util.sleep(1000); + } + } + } + + private XMPPTCPConnection connect(long senderId, String apiKey) + throws XMPPException, IOException, SmackException + { + ConnectionConfiguration config = new ConnectionConfiguration(GCM_SERVER, GCM_PORT); + config.setSecurityMode(ConnectionConfiguration.SecurityMode.enabled); + config.setReconnectionAllowed(true); + config.setRosterLoadedAtLogin(false); + config.setSendPresence(false); + config.setSocketFactory(SSLSocketFactory.getDefault()); + + XMPPTCPConnection connection = new XMPPTCPConnection(config); + connection.connect(); + + connection.addConnectionListener(new LoggingConnectionListener()); + connection.addPacketListener(this, new PacketTypeFilter(Message.class)); + + connection.login(senderId + "@gcm.googleapis.com", apiKey); + + return connection; + } + + private static class GcmPacketExtensionProvider implements PacketExtensionProvider { + @Override + public PacketExtension parseExtension(XmlPullParser xmlPullParser) throws Exception { + String json = xmlPullParser.nextText(); + return new GcmPacketExtension(json); + } + } + + private static final class GcmPacketExtension extends DefaultPacketExtension { + + private final String json; + + public GcmPacketExtension(String json) { + super(GCM_ELEMENT_NAME, GCM_NAMESPACE); + this.json = json; + } + + public String getJson() { + return json; + } + + @Override + public String toXML() { + return String.format("<%s xmlns=\"%s\">%s", GCM_ELEMENT_NAME, GCM_NAMESPACE, + StringUtils.escapeForXML(json), GCM_ELEMENT_NAME); + } + + public Packet toPacket() { + Message message = new Message(); + message.addExtension(this); + return message; + } + } + + private class LoggingConnectionListener implements ConnectionListener { + + @Override + public void connected(XMPPConnection xmppConnection) { + logger.warn("GCM XMPP Connected."); + } + + @Override + public void authenticated(XMPPConnection xmppConnection) { + logger.warn("GCM XMPP Authenticated."); + } + + @Override + public void reconnectionSuccessful() { + logger.warn("GCM XMPP Reconnecting.."); + Iterator> iterator = + pendingMessages.entrySet().iterator(); + + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + iterator.remove(); + + sendMessage(entry.getKey(), entry.getValue()); + } + } + + @Override + public void reconnectionFailed(Exception e) { + logger.warn("GCM XMPP Reconnection failed!", e); + } + + @Override + public void reconnectingIn(int seconds) { + logger.warn(String.format("GCM XMPP Reconnecting in %d secs", seconds)); + } + + @Override + public void connectionClosedOnError(Exception e) { + logger.warn("GCM XMPP Connection closed on error."); + } + + @Override + public void connectionClosed() { + logger.warn("GCM XMPP Connection closed."); + } + } + + private static class UnacknowledgedMessage { + private final String destinationNumber; + private final long destinationDeviceId; + + private final String registrationId; + private final PendingMessage pendingMessage; + + private UnacknowledgedMessage(String destinationNumber, + long destinationDeviceId, + String registrationId, + PendingMessage pendingMessage) + { + this.destinationNumber = destinationNumber; + this.destinationDeviceId = destinationDeviceId; + this.registrationId = registrationId; + this.pendingMessage = pendingMessage; + } + + private String getRegistrationId() { + return registrationId; + } + + private PendingMessage getPendingMessage() { + return pendingMessage; + } + + public String getDestinationNumber() { + return destinationNumber; + } + + public long getDestinationDeviceId() { + return destinationDeviceId; } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index c9df1f9c2..fc458f3f9 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -18,44 +18,28 @@ package org.whispersystems.textsecuregcm.push; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; -import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.StoredMessages; - -import java.io.IOException; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; public class PushSender { private final Logger logger = LoggerFactory.getLogger(PushSender.class); - private final AccountsManager accounts; private final GCMSender gcmSender; private final APNSender apnSender; private final WebsocketSender webSocketSender; - public PushSender(GcmConfiguration gcmConfiguration, - ApnConfiguration apnConfiguration, - StoredMessages storedMessages, - PubSubManager pubSubManager, - AccountsManager accounts) - throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException + public PushSender(GCMSender gcmClient, + APNSender apnSender, + WebsocketSender websocketSender) { - this.accounts = accounts; - this.webSocketSender = new WebsocketSender(storedMessages, pubSubManager); - this.gcmSender = new GCMSender(gcmConfiguration.getApiKey()); - this.apnSender = new APNSender(pubSubManager, storedMessages, - apnConfiguration.getCertificate(), - apnConfiguration.getKey()); + this.gcmSender = gcmClient; + this.apnSender = apnSender; + this.webSocketSender = websocketSender; } public void sendMessage(Account account, Device device, MessageProtos.OutgoingMessageSignal message) @@ -64,60 +48,39 @@ public class PushSender { try { String signalingKey = device.getSignalingKey(); EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, signalingKey); + PendingMessage pendingMessage = new PendingMessage(message.getSource(), message.getTimestamp(), encryptedMessage.serialize()); - sendMessage(account, device, encryptedMessage); + sendMessage(account, device, pendingMessage); } catch (CryptoEncodingException e) { throw new NotPushRegisteredException(e); } } - public void sendMessage(Account account, Device device, EncryptedOutgoingMessage message) + public void sendMessage(Account account, Device device, PendingMessage pendingMessage) throws NotPushRegisteredException, TransientPushFailureException { - if (device.getGcmId() != null) sendGcmMessage(account, device, message); - else if (device.getApnId() != null) sendApnMessage(account, device, message); - else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); + if (device.getGcmId() != null) sendGcmMessage(account, device, pendingMessage); + else if (device.getApnId() != null) sendApnMessage(account, device, pendingMessage); + else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, pendingMessage); else throw new NotPushRegisteredException("No delivery possible!"); } - private void sendGcmMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) - throws NotPushRegisteredException, TransientPushFailureException - { - try { - String canonicalId = gcmSender.sendMessage(device.getGcmId(), outgoingMessage); + private void sendGcmMessage(Account account, Device device, PendingMessage pendingMessage) { + String number = account.getNumber(); + long deviceId = device.getId(); + String registrationId = device.getGcmId(); - if (canonicalId != null) { - device.setGcmId(canonicalId); - accounts.update(account); - } - - } catch (NotPushRegisteredException e) { - logger.debug("No Such User", e); - device.setGcmId(null); - accounts.update(account); - throw new NotPushRegisteredException(e); - } + gcmSender.sendMessage(number, deviceId, registrationId, pendingMessage); } - private void sendApnMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) - throws TransientPushFailureException, NotPushRegisteredException + private void sendApnMessage(Account account, Device device, PendingMessage outgoingMessage) + throws TransientPushFailureException { - try { - apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage); - } catch (NotPushRegisteredException e) { - device.setApnId(null); - accounts.update(account); - throw new NotPushRegisteredException(e); - } + apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage); } - private void sendWebSocketMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) - throws NotPushRegisteredException + private void sendWebSocketMessage(Account account, Device device, PendingMessage outgoingMessage) { - try { - webSocketSender.sendMessage(account, device, outgoingMessage); - } catch (CryptoEncodingException e) { - throw new NotPushRegisteredException(e); - } + webSocketSender.sendMessage(account, device, outgoingMessage); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index c549da4b3..612772e94 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -19,8 +19,12 @@ package org.whispersystems.textsecuregcm.push; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; -import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; -import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.controllers.WebsocketController; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.PubSubManager; @@ -29,16 +33,18 @@ import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; -import java.util.List; - import static com.codahale.metrics.MetricRegistry.name; public class WebsocketSender { + private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class); + private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online")); private final Meter offlineMeter = metricRegistry.meter(name(getClass(), "offline")); + private static final ObjectMapper mapper = new ObjectMapper(); + private final StoredMessages storedMessages; private final PubSubManager pubSubManager; @@ -47,22 +53,21 @@ public class WebsocketSender { this.pubSubManager = pubSubManager; } - public void sendMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) - throws CryptoEncodingException - { - sendMessage(account, device, outgoingMessage.serialize()); - } + public void sendMessage(Account account, Device device, PendingMessage pendingMessage) { + try { + String serialized = mapper.writeValueAsString(pendingMessage); + WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId()); + PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized); - private void sendMessage(Account account, Device device, String serializedMessage) { - WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId()); - PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage); - - if (pubSubManager.publish(address, pubSubMessage)) { - onlineMeter.mark(); - } else { - offlineMeter.mark(); - storedMessages.insert(account.getId(), device.getId(), serializedMessage); - pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); + if (pubSubManager.publish(address, pubSubMessage)) { + onlineMeter.mark(); + } else { + offlineMeter.mark(); + storedMessages.insert(account.getId(), device.getId(), pendingMessage); + pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); + } + } catch (JsonProcessingException e) { + logger.warn("WebsocketSender", "Unable to serialize json", e); } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java index b954e4cda..3a52da634 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java @@ -19,9 +19,15 @@ package org.whispersystems.textsecuregcm.storage; import com.codahale.metrics.Histogram; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.util.Constants; +import java.io.IOException; import java.util.LinkedList; import java.util.List; @@ -31,9 +37,13 @@ import redis.clients.jedis.JedisPool; public class StoredMessages { + private static final Logger logger = LoggerFactory.getLogger(StoredMessages.class); + private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Histogram queueSizeHistogram = metricRegistry.histogram(name(getClass(), "queue_size")); + + private static final ObjectMapper mapper = new ObjectMapper(); private static final String QUEUE_PREFIX = "msgs"; private final JedisPool jedisPool; @@ -42,34 +52,42 @@ public class StoredMessages { this.jedisPool = jedisPool; } - public void insert(long accountId, long deviceId, String message) { + public void insert(long accountId, long deviceId, PendingMessage message) { Jedis jedis = null; try { jedis = jedisPool.getResource(); - long queueSize = jedis.lpush(getKey(accountId, deviceId), message); + String serializedMessage = mapper.writeValueAsString(message); + long queueSize = jedis.lpush(getKey(accountId, deviceId), serializedMessage); queueSizeHistogram.update(queueSize); if (queueSize > 1000) { jedis.ltrim(getKey(accountId, deviceId), 0, 999); } + + } catch (JsonProcessingException e) { + logger.warn("StoredMessages", "Unable to store correctly", e); } finally { if (jedis != null) jedisPool.returnResource(jedis); } } - public List getMessagesForDevice(long accountId, long deviceId) { - List messages = new LinkedList<>(); - Jedis jedis = null; + public List getMessagesForDevice(long accountId, long deviceId) { + List messages = new LinkedList<>(); + Jedis jedis = null; try { jedis = jedisPool.getResource(); String message; while ((message = jedis.rpop(getKey(accountId, deviceId))) != null) { - messages.add(message); + try { + messages.add(mapper.readValue(message, PendingMessage.class)); + } catch (IOException e) { + logger.warn("StoredMessages", "Not a valid PendingMessage", e); + } } return messages; diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java index 8b9b960f6..b1e5c256b 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java +++ b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java @@ -83,4 +83,40 @@ public class Util { return result; } + + public static byte[][] split(byte[] input, int firstLength, int secondLength) { + byte[][] parts = new byte[2][]; + + parts[0] = new byte[firstLength]; + System.arraycopy(input, 0, parts[0], 0, firstLength); + + parts[1] = new byte[secondLength]; + System.arraycopy(input, firstLength, parts[1], 0, secondLength); + + return parts; + } + + public static byte[][] split(byte[] input, int firstLength, int secondLength, int thirdLength, int fourthLength) { + byte[][] parts = new byte[4][]; + + parts[0] = new byte[firstLength]; + System.arraycopy(input, 0, parts[0], 0, firstLength); + + parts[1] = new byte[secondLength]; + System.arraycopy(input, firstLength, parts[1], 0, secondLength); + + parts[2] = new byte[thirdLength]; + System.arraycopy(input, firstLength + secondLength, parts[2], 0, thirdLength); + + parts[3] = new byte[fourthLength]; + System.arraycopy(input, firstLength + secondLength + thirdLength, parts[3], 0, fourthLength); + + return parts; + } + + public static void sleep(int i) { + try { + Thread.sleep(i); + } catch (InterruptedException ie) {} + } } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java index 296ef0331..f1006c2fb 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java @@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.controllers.WebsocketController; import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; +import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -83,10 +84,10 @@ public class WebsocketControllerTest { public void testOpen() throws Exception { RemoteEndpoint remote = mock(RemoteEndpoint.class); - List outgoingMessages = new LinkedList() {{ - add("first"); - add("second"); - add("third"); + List outgoingMessages = new LinkedList() {{ + add(new PendingMessage("sender1", 1111, "first")); + add(new PendingMessage("sender1", 2222, "second")); + add(new PendingMessage("sender2", 3333, "third")); }}; when(device.getId()).thenReturn(2L); @@ -103,7 +104,8 @@ public class WebsocketControllerTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) .thenReturn(Optional.of(account)); - when(storedMessages.getMessagesForDevice(account.getId(), device.getId())).thenReturn(outgoingMessages); + when(storedMessages.getMessagesForDevice(account.getId(), device.getId())) + .thenReturn(outgoingMessages); WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, pushSender, storedMessages, pubSubManager); WebsocketController controller = (WebsocketController) factory.createWebSocket(null, null); @@ -116,12 +118,13 @@ public class WebsocketControllerTest { controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1))); controller.onWebSocketClose(1000, "Closed"); - List pending = new LinkedList() {{ - add("first"); - add("third"); + List pending = new LinkedList() {{ + add(new PendingMessage("sender1", 1111, "first")); + add(new PendingMessage("sender2", 3333, "third")); }}; - verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(EncryptedOutgoingMessage.class)); + + verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(PendingMessage.class)); } }