Pull GCM/APN senders into service

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2017-04-25 18:03:07 -07:00
parent 28939e7405
commit 189d95f4fa
19 changed files with 998 additions and 510 deletions

20
pom.xml
View File

@ -13,6 +13,7 @@
<properties>
<dropwizard.version>0.9.2</dropwizard.version>
<jackson.api.version>2.6.0</jackson.api.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
@ -103,6 +104,17 @@
<version>0.1.1</version>
</dependency>
<dependency>
<groupId>com.notnoop.apns</groupId>
<artifactId>apns</artifactId>
<version>0.2.3</version>
</dependency>
<dependency>
<groupId>org.whispersystems</groupId>
<artifactId>gcm-sender-async</artifactId>
<version>0.1.4</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.test-framework.providers</groupId>
<artifactId>jersey-test-framework-provider-grizzly2</artifactId>
@ -120,6 +132,8 @@
</exclusions>
</dependency>
</dependencies>
<dependencyManagement>
@ -134,6 +148,12 @@
<artifactId>httpcore</artifactId>
<version>4.4.1</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.api.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

View File

@ -17,7 +17,9 @@
package org.whispersystems.textsecuregcm;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.FederationConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.GraphiteConfiguration;
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.PushConfiguration;
@ -117,6 +119,16 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private TurnConfiguration turn;
@Valid
@NotNull
@JsonProperty
private GcmConfiguration gcm;
@Valid
@NotNull
@JsonProperty
private ApnConfiguration apn;
public WebsocketConfiguration getWebsocketConfiguration() {
return websocket;
@ -174,6 +186,14 @@ public class WhisperServerConfiguration extends Configuration {
return turn;
}
public GcmConfiguration getGcmConfiguration() {
return gcm;
}
public ApnConfiguration getApnConfiguration() {
return apn;
}
public Map<String, Integer> getTestDevices() {
Map<String, Integer> results = new HashMap<>();
@ -195,4 +215,5 @@ public class WhisperServerConfiguration extends Configuration {
return results;
}
}

View File

@ -60,10 +60,10 @@ import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisHealthCheck;
import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.FeedbackHandler;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.PushServiceClient;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.sms.SmsSender;
@ -89,7 +89,6 @@ import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
import org.whispersystems.textsecuregcm.workers.DirectoryCommand;
import org.whispersystems.textsecuregcm.workers.PeriodicStatsCommand;
import org.whispersystems.textsecuregcm.workers.PushCommand;
import org.whispersystems.textsecuregcm.workers.TrimMessagesCommand;
import org.whispersystems.textsecuregcm.workers.VacuumCommand;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@ -124,7 +123,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new TrimMessagesCommand());
bootstrap.addCommand(new PeriodicStatsCommand());
bootstrap.addCommand(new DeleteUserCommand());
bootstrap.addCommand(new PushCommand());
bootstrap.addBundle(new NameableMigrationsBundle<WhisperServerConfiguration>("accountdb", "accountsdb.xml") {
@Override
public DataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) {
@ -178,25 +176,24 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);
DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.<DispatchChannel>of(deadLetterHandler));
PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
PushServiceClient pushServiceClient = new PushServiceClient(httpClient, config.getPushConfiguration());
APNSender apnSender = new APNSender(accountsManager, cacheClient, config.getApnConfiguration());
GCMSender gcmSender = new GCMSender(accountsManager, config.getGcmConfiguration().getApiKey());
WebsocketSender websocketSender = new WebsocketSender(messagesManager, pubSubManager);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager );
FederatedPeerAuthenticator federatedPeerAuthenticator = new FederatedPeerAuthenticator(config.getFederationConfiguration());
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), cacheClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(apnSender, pubSubManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender);
UrlSigner urlSigner = new UrlSigner(config.getS3Configuration());
PushSender pushSender = new PushSender(apnFallbackManager, pushServiceClient, websocketSender, config.getPushConfiguration().getQueueSize());
PushSender pushSender = new PushSender(apnFallbackManager, gcmSender, apnSender, websocketSender, config.getPushConfiguration().getQueueSize());
ReceiptSender receiptSender = new ReceiptSender(accountsManager, pushSender, federatedClientManager);
FeedbackHandler feedbackHandler = new FeedbackHandler(pushServiceClient, accountsManager);
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(config.getTurnConfiguration());
Optional<byte[]> authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey();
environment.lifecycle().manage(apnFallbackManager);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(feedbackHandler);
environment.lifecycle().manage(pushSender);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);

View File

@ -0,0 +1,70 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
public class ApnConfiguration {
@NotEmpty
@JsonProperty
private String pushCertificate;
@NotEmpty
@JsonProperty
private String pushKey;
@NotEmpty
@JsonProperty
private String voipCertificate;
@NotEmpty
@JsonProperty
private String voipKey;
@JsonProperty
private boolean feedback = true;
@JsonProperty
private boolean sandbox = false;
public String getPushCertificate() {
return pushCertificate;
}
public String getPushKey() {
return pushKey;
}
public String getVoipCertificate() {
return voipCertificate;
}
public String getVoipKey() {
return voipKey;
}
public boolean isFeedbackEnabled() {
return feedback;
}
public boolean isSandboxEnabled() {
return sandbox;
}
}

View File

@ -0,0 +1,42 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
public class GcmConfiguration {
@NotNull
@JsonProperty
private long senderId;
@NotEmpty
@JsonProperty
private String apiKey;
public String getApiKey() {
return apiKey;
}
public long getSenderId() {
return senderId;
}
}

View File

@ -1,42 +0,0 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.Min;
public class GcmMessage {
@JsonProperty
@NotEmpty
private String gcmId;
@JsonProperty
@NotEmpty
private String number;
@JsonProperty
@Min(1)
private int deviceId;
@JsonProperty
private String message;
@JsonProperty
private boolean receipt;
@JsonProperty
private boolean notification;
public GcmMessage() {}
public GcmMessage(String gcmId, String number, int deviceId, String message, boolean receipt, boolean notification) {
this.gcmId = gcmId;
this.number = number;
this.deviceId = deviceId;
this.message = message;
this.receipt = receipt;
this.notification = notification;
}
}

View File

@ -0,0 +1,254 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.push;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.notnoop.apns.APNS;
import com.notnoop.apns.ApnsService;
import com.notnoop.apns.ApnsServiceBuilder;
import com.notnoop.exceptions.NetworkIOException;
import org.bouncycastle.openssl.PEMReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
public class APNSender implements Managed {
private final Logger logger = LoggerFactory.getLogger(APNSender.class);
private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
private final AccountsManager accountsManager;
private final JedisPool jedisPool;
private final String pushCertificate;
private final String pushKey;
private final String voipCertificate;
private final String voipKey;
private final boolean feedbackEnabled;
private final boolean sandbox;
private ApnsService pushApnService;
private ApnsService voipApnService;
public APNSender(AccountsManager accountsManager,
JedisPool jedisPool,
ApnConfiguration configuration)
{
this.accountsManager = accountsManager;
this.jedisPool = jedisPool;
this.pushCertificate = configuration.getPushCertificate();
this.pushKey = configuration.getPushKey();
this.voipCertificate = configuration.getVoipCertificate();
this.voipKey = configuration.getVoipKey();
this.feedbackEnabled = configuration.isFeedbackEnabled();
this.sandbox = configuration.isSandboxEnabled();
}
@VisibleForTesting
public APNSender(AccountsManager accountsManager, JedisPool jedisPool,
ApnsService pushApnService, ApnsService voipApnService,
boolean feedbackEnabled, boolean sandbox)
{
this.accountsManager = accountsManager;
this.jedisPool = jedisPool;
this.pushApnService = pushApnService;
this.voipApnService = voipApnService;
this.feedbackEnabled = feedbackEnabled;
this.sandbox = sandbox;
this.pushCertificate = null;
this.pushKey = null;
this.voipCertificate = null;
this.voipKey = null;
}
public void sendMessage(ApnMessage message)
throws TransientPushFailureException
{
try {
redisSet(message.getApnId(), message.getNumber(), message.getDeviceId());
if (message.isVoip()) {
voipApnService.push(message.getApnId(), message.getMessage(), new Date(message.getExpirationTime()));
} else {
pushApnService.push(message.getApnId(), message.getMessage(), new Date(message.getExpirationTime()));
}
} catch (NetworkIOException nioe) {
logger.warn("Network Error", nioe);
throw new TransientPushFailureException(nioe);
}
}
private static byte[] initializeKeyStore(String pemCertificate, String pemKey)
throws KeyStoreException, CertificateException, NoSuchAlgorithmException, IOException
{
PEMReader reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemCertificate.getBytes())));
X509Certificate certificate = (X509Certificate) reader.readObject();
Certificate[] certificateChain = {certificate};
reader = new PEMReader(new InputStreamReader(new ByteArrayInputStream(pemKey.getBytes())));
KeyPair keyPair = (KeyPair) reader.readObject();
KeyStore keyStore = KeyStore.getInstance("pkcs12");
keyStore.load(null);
keyStore.setEntry("apn",
new KeyStore.PrivateKeyEntry(keyPair.getPrivate(), certificateChain),
new KeyStore.PasswordProtection("insecure".toCharArray()));
ByteArrayOutputStream baos = new ByteArrayOutputStream();
keyStore.store(baos, "insecure".toCharArray());
return baos.toByteArray();
}
@Override
public void start() throws Exception {
byte[] pushKeyStore = initializeKeyStore(pushCertificate, pushKey);
byte[] voipKeyStore = initializeKeyStore(voipCertificate, voipKey);
ApnsServiceBuilder pushApnServiceBuilder = APNS.newService()
.withCert(new ByteArrayInputStream(pushKeyStore), "insecure")
.asQueued();
ApnsServiceBuilder voipApnServiceBuilder = APNS.newService()
.withCert(new ByteArrayInputStream(voipKeyStore), "insecure")
.asQueued();
if (sandbox) {
this.pushApnService = pushApnServiceBuilder.withSandboxDestination().build();
this.voipApnService = voipApnServiceBuilder.withSandboxDestination().build();
} else {
this.pushApnService = pushApnServiceBuilder.withProductionDestination().build();
this.voipApnService = voipApnServiceBuilder.withProductionDestination().build();
}
if (feedbackEnabled) {
this.executor.scheduleAtFixedRate(new FeedbackRunnable(), 0, 1, TimeUnit.HOURS);
}
}
@Override
public void stop() throws Exception {
pushApnService.stop();
voipApnService.stop();
}
private void redisSet(String registrationId, String number, int deviceId) {
try (Jedis jedis = jedisPool.getResource()) {
jedis.set("APN-" + registrationId.toLowerCase(), number + "." + deviceId);
jedis.expire("APN-" + registrationId.toLowerCase(), (int) TimeUnit.HOURS.toSeconds(1));
}
}
private Optional<String> redisGet(String registrationId) {
try (Jedis jedis = jedisPool.getResource()) {
String number = jedis.get("APN-" + registrationId.toLowerCase());
return Optional.fromNullable(number);
}
}
@VisibleForTesting
public void checkFeedback() {
new FeedbackRunnable().run();
}
private class FeedbackRunnable implements Runnable {
@Override
public void run() {
try {
Map<String, Date> inactiveDevices = pushApnService.getInactiveDevices();
inactiveDevices.putAll(voipApnService.getInactiveDevices());
for (String registrationId : inactiveDevices.keySet()) {
Optional<String> device = redisGet(registrationId);
if (device.isPresent()) {
logger.warn("Got APN unregistered notice!");
String[] parts = device.get().split("\\.", 2);
if (parts.length == 2) {
String number = parts[0];
int deviceId = Integer.parseInt(parts[1]);
long timestamp = inactiveDevices.get(registrationId).getTime();
handleApnUnregistered(registrationId, number, deviceId, timestamp);
} else {
logger.warn("APN unregister event for device with no parts: " + device.get());
}
} else {
logger.warn("APN unregister event received for uncached ID: " + registrationId);
}
}
} catch (Throwable t) {
logger.warn("Exception during feedback", t);
}
}
private void handleApnUnregistered(String registrationId, String number, int deviceId, long timestamp) {
logger.info("Got APN Unregistered: " + number + "," + deviceId);
Optional<Account> account = accountsManager.get(number);
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(deviceId);
if (device.isPresent()) {
if (registrationId.equals(device.get().getApnId())) {
logger.info("APN Unregister APN ID matches!");
if (device.get().getPushTimestamp() == 0 ||
timestamp > device.get().getPushTimestamp())
{
logger.info("APN Unregister timestamp matches!");
device.get().setApnId(null);
accountsManager.update(account.get());
}
}
}
}
}
}
}

View File

@ -10,7 +10,6 @@ import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.Constants;
@ -43,12 +42,12 @@ public class ApnFallbackManager implements Managed, Runnable, DispatchChannel {
private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
private final PushServiceClient pushServiceClient;
private final PubSubManager pubSubManager;
private final APNSender apnSender;
private final PubSubManager pubSubManager;
public ApnFallbackManager(PushServiceClient pushServiceClient, PubSubManager pubSubManager) {
this.pushServiceClient = pushServiceClient;
this.pubSubManager = pubSubManager;
public ApnFallbackManager(APNSender apnSender, PubSubManager pubSubManager) {
this.apnSender = apnSender;
this.pubSubManager = pubSubManager;
}
public void schedule(final WebsocketAddress address, ApnFallbackTask task) {
@ -102,7 +101,7 @@ public class ApnFallbackManager implements Managed, Runnable, DispatchChannel {
pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this);
}
pushServiceClient.send(message);
apnSender.sendMessage(message);
} catch (Throwable e) {
logger.warn("ApnFallbackThread", e);
}

View File

@ -1,40 +1,15 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
package org.whispersystems.textsecuregcm.push;
public class ApnMessage {
public static long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L;
@JsonProperty
@NotEmpty
private String apnId;
@JsonProperty
@NotEmpty
private String number;
@JsonProperty
@Min(1)
private int deviceId;
@JsonProperty
@NotEmpty
private String message;
@JsonProperty
@NotNull
private boolean voip;
@JsonProperty
private long expirationTime;
public ApnMessage() {}
private final String apnId;
private final String number;
private final int deviceId;
private final String message;
private final boolean voip;
private final long expirationTime;
public ApnMessage(String apnId, String number, int deviceId, String message, boolean voip, long expirationTime) {
this.apnId = apnId;
@ -54,23 +29,27 @@ public class ApnMessage {
this.expirationTime = expirationTime;
}
@VisibleForTesting
public String getApnId() {
return apnId;
}
@VisibleForTesting
public boolean isVoip() {
return voip;
}
@VisibleForTesting
public String getMessage() {
return message;
}
@VisibleForTesting
public long getExpirationTime() {
return expirationTime;
}
public String getNumber() {
return number;
}
public int getDeviceId() {
return deviceId;
}
}

View File

@ -1,120 +0,0 @@
package org.whispersystems.textsecuregcm.push;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.UnregisteredEvent;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import io.dropwizard.lifecycle.Managed;
public class FeedbackHandler implements Managed, Runnable {
private final Logger logger = LoggerFactory.getLogger(PushServiceClient.class);
private final PushServiceClient client;
private final AccountsManager accountsManager;
private ScheduledExecutorService executor;
public FeedbackHandler(PushServiceClient client, AccountsManager accountsManager) {
this.client = client;
this.accountsManager = accountsManager;
}
@Override
public void start() throws Exception {
this.executor = Executors.newSingleThreadScheduledExecutor();
this.executor.scheduleAtFixedRate(this, 0, 1, TimeUnit.MINUTES);
}
@Override
public void stop() throws Exception {
if (this.executor != null) {
this.executor.shutdown();
}
}
@Override
public void run() {
logger.info("Checking Push Server feedback...");
try {
List<UnregisteredEvent> gcmFeedback = client.getGcmFeedback();
List<UnregisteredEvent> apnFeedback = client.getApnFeedback();
logger.info("Got GCM feedback: " + gcmFeedback.size());
logger.info("Got APN feedback: " + apnFeedback.size());
for (UnregisteredEvent gcmEvent : gcmFeedback) {
handleGcmUnregistered(gcmEvent);
}
for (UnregisteredEvent apnEvent : apnFeedback) {
handleApnUnregistered(apnEvent);
}
} catch (Throwable t) {
logger.warn("Error retrieving feedback: ", t);
}
}
private void handleGcmUnregistered(UnregisteredEvent event) {
logger.info("Got GCM Unregistered: " + event.getNumber() + "," + event.getDeviceId());
Optional<Account> account = accountsManager.get(event.getNumber());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(event.getDeviceId());
if (device.isPresent()) {
if (event.getRegistrationId().equals(device.get().getGcmId())) {
logger.info("GCM Unregister GCM ID matches!");
if (device.get().getPushTimestamp() == 0 ||
event.getTimestamp() > (device.get().getPushTimestamp() + TimeUnit.SECONDS.toMillis(10)))
{
logger.info("GCM Unregister Timestamp matches!");
if (event.getCanonicalId() != null && !event.getCanonicalId().isEmpty()) {
logger.info("It's a canonical ID update...");
device.get().setGcmId(event.getCanonicalId());
} else {
device.get().setGcmId(null);
device.get().setFetchesMessages(false);
}
accountsManager.update(account.get());
}
}
}
}
}
private void handleApnUnregistered(UnregisteredEvent event) {
logger.info("Got APN Unregistered: " + event.getNumber() + "," + event.getDeviceId());
Optional<Account> account = accountsManager.get(event.getNumber());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(event.getDeviceId());
if (device.isPresent()) {
if (event.getRegistrationId().equals(device.get().getApnId())) {
logger.info("APN Unregister APN ID matches!");
if (device.get().getPushTimestamp() == 0 ||
event.getTimestamp() > device.get().getPushTimestamp())
{
logger.info("APN Unregister timestamp matches!");
device.get().setApnId(null);
accountsManager.update(account.get());
}
}
}
}
}
}

View File

@ -0,0 +1,178 @@
package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.gcm.server.Message;
import org.whispersystems.gcm.server.Result;
import org.whispersystems.gcm.server.Sender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Constants;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
public class GCMSender implements Managed {
private final Logger logger = LoggerFactory.getLogger(GCMSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter success = metricRegistry.meter(name(getClass(), "sent", "success"));
private final Meter failure = metricRegistry.meter(name(getClass(), "sent", "failure"));
private final Meter unregistered = metricRegistry.meter(name(getClass(), "sent", "unregistered"));
private final Meter canonical = metricRegistry.meter(name(getClass(), "sent", "canonical"));
private final Map<String, Meter> outboundMeters = new HashMap<String, Meter>() {{
put("receipt", metricRegistry.meter(name(getClass(), "outbound", "receipt")));
put("notification", metricRegistry.meter(name(getClass(), "outbound", "notification")));
}};
private final AccountsManager accountsManager;
private final Sender signalSender;
private ExecutorService executor;
public GCMSender(AccountsManager accountsManager, String signalKey) {
this.accountsManager = accountsManager;
this.signalSender = new Sender(signalKey, 50);
}
@VisibleForTesting
public GCMSender(AccountsManager accountsManager, Sender sender, ExecutorService executor) {
this.accountsManager = accountsManager;
this.signalSender = sender;
this.executor = executor;
}
public void sendMessage(GcmMessage message) {
Message.Builder builder = Message.newBuilder()
.withDestination(message.getGcmId())
.withPriority("high");
String key = message.isReceipt() ? "receipt" : "notification";
Message request = builder.withDataPart(key, "").build();
ListenableFuture<Result> future = signalSender.send(request, message);
markOutboundMeter(key);
Futures.addCallback(future, new FutureCallback<Result>() {
@Override
public void onSuccess(Result result) {
if (result.isUnregistered() || result.isInvalidRegistrationId()) {
handleBadRegistration(result);
} else if (result.hasCanonicalRegistrationId()) {
handleCanonicalRegistrationId(result);
} else if (!result.isSuccess()) {
handleGenericError(result);
} else {
success.mark();
}
}
@Override
public void onFailure(Throwable throwable) {
logger.warn("GCM Failed: " + throwable);
}
}, executor);
}
@Override
public void start() {
executor = Executors.newSingleThreadExecutor();
}
@Override
public void stop() throws IOException {
this.signalSender.stop();
this.executor.shutdown();
}
private void handleBadRegistration(Result result) {
GcmMessage message = (GcmMessage)result.getContext();
logger.warn("Got GCM unregistered notice! " + message.getGcmId());
Optional<Account> account = getAccountForEvent(message);
if (account.isPresent()) {
Device device = account.get().getDevice(message.getDeviceId()).get();
device.setGcmId(null);
device.setFetchesMessages(false);
accountsManager.update(account.get());
}
unregistered.mark();
}
private void handleCanonicalRegistrationId(Result result) {
GcmMessage message = (GcmMessage)result.getContext();
logger.warn(String.format("Actually received 'CanonicalRegistrationId' ::: (canonical=%s), (original=%s)",
result.getCanonicalRegistrationId(), message.getGcmId()));
Optional<Account> account = getAccountForEvent(message);
if (account.isPresent()) {
Device device = account.get().getDevice(message.getDeviceId()).get();
device.setGcmId(result.getCanonicalRegistrationId());
accountsManager.update(account.get());
}
canonical.mark();
}
private void handleGenericError(Result result) {
GcmMessage message = (GcmMessage)result.getContext();
logger.warn(String.format("Unrecoverable Error ::: (error=%s), (gcm_id=%s), " +
"(destination=%s), (device_id=%d)",
result.getError(), message.getGcmId(), message.getNumber(),
message.getDeviceId()));
failure.mark();
}
private Optional<Account> getAccountForEvent(GcmMessage message) {
Optional<Account> account = accountsManager.get(message.getNumber());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(message.getDeviceId());
if (device.isPresent()) {
if (message.getGcmId().equals(device.get().getGcmId())) {
logger.info("GCM Unregister GCM ID matches!");
if (device.get().getPushTimestamp() == 0 || System.currentTimeMillis() > (device.get().getPushTimestamp() + TimeUnit.SECONDS.toMillis(10)))
{
logger.info("GCM Unregister Timestamp matches!");
return account;
}
}
}
}
return Optional.absent();
}
private void markOutboundMeter(String key) {
Meter meter = outboundMeters.get(key);
if (meter != null) meter.mark();
else logger.warn("Unknown outbound key: " + key);
}
}

View File

@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.push;
public class GcmMessage {
private final String gcmId;
private final String number;
private final int deviceId;
private final boolean receipt;
public GcmMessage(String gcmId, String number, int deviceId, boolean receipt) {
this.gcmId = gcmId;
this.number = number;
this.deviceId = deviceId;
this.receipt = receipt;
}
public String getGcmId() {
return gcmId;
}
public String getNumber() {
return number;
}
public boolean isReceipt() {
return receipt;
}
public int getDeviceId() {
return deviceId;
}
}

View File

@ -1,4 +1,4 @@
/**
/*
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
@ -20,8 +20,6 @@ import com.codahale.metrics.Gauge;
import com.codahale.metrics.SharedMetricRegistries;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.entities.GcmMessage;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus;
import org.whispersystems.textsecuregcm.storage.Account;
@ -44,16 +42,19 @@ public class PushSender implements Managed {
public static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"badge\":%d,\"alert\":{\"loc-key\":\"APN_Message\"}}}";
private final ApnFallbackManager apnFallbackManager;
private final PushServiceClient pushServiceClient;
private final GCMSender gcmSender;
private final APNSender apnSender;
private final WebsocketSender webSocketSender;
private final BlockingThreadPoolExecutor executor;
private final int queueSize;
public PushSender(ApnFallbackManager apnFallbackManager, PushServiceClient pushServiceClient,
public PushSender(ApnFallbackManager apnFallbackManager,
GCMSender gcmSender, APNSender apnSender,
WebsocketSender websocketSender, int queueSize)
{
this.apnFallbackManager = apnFallbackManager;
this.pushServiceClient = pushServiceClient;
this.gcmSender = gcmSender;
this.apnSender = apnSender;
this.webSocketSender = websocketSender;
this.queueSize = queueSize;
this.executor = new BlockingThreadPoolExecutor(50, queueSize);
@ -115,14 +116,10 @@ public class PushSender implements Managed {
}
private void sendGcmNotification(Account account, Device device) {
try {
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), "", false, true);
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), false);
pushServiceClient.send(gcmMessage);
} catch (TransientPushFailureException e) {
logger.warn("SILENT PUSH LOSS", e);
}
gcmSender.sendMessage(gcmMessage);
}
private void sendApnMessage(Account account, Device device, Envelope outgoingMessage, boolean silent) {
@ -153,7 +150,7 @@ public class PushSender implements Managed {
}
try {
pushServiceClient.send(apnMessage);
apnSender.sendMessage(apnMessage);
} catch (TransientPushFailureException e) {
logger.warn("SILENT PUSH LOSS", e);
}
@ -166,12 +163,16 @@ public class PushSender implements Managed {
@Override
public void start() throws Exception {
apnSender.start();
gcmSender.start();
}
@Override
public void stop() throws Exception {
executor.shutdown();
executor.awaitTermination(5, TimeUnit.MINUTES);
apnSender.stop();
gcmSender.stop();
}
}

View File

@ -1,94 +0,0 @@
package org.whispersystems.textsecuregcm.push;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.PushConfiguration;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.entities.GcmMessage;
import org.whispersystems.textsecuregcm.entities.UnregisteredEvent;
import org.whispersystems.textsecuregcm.entities.UnregisteredEventList;
import org.whispersystems.textsecuregcm.util.Base64;
import javax.ws.rs.ProcessingException;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.List;
public class PushServiceClient {
private static final String PUSH_GCM_PATH = "/api/v1/push/gcm";
private static final String PUSH_APN_PATH = "/api/v1/push/apn";
private static final String APN_FEEDBACK_PATH = "/api/v1/feedback/apn";
private static final String GCM_FEEDBACK_PATH = "/api/v1/feedback/gcm";
private final Logger logger = LoggerFactory.getLogger(PushServiceClient.class);
private final Client client;
private final String host;
private final int port;
private final String authorization;
public PushServiceClient(Client client, PushConfiguration config) {
this.client = client;
this.host = config.getHost();
this.port = config.getPort();
this.authorization = getAuthorizationHeader(config.getUsername(), config.getPassword());
}
public void send(GcmMessage message) throws TransientPushFailureException {
sendPush(PUSH_GCM_PATH, message);
}
public void send(ApnMessage message) throws TransientPushFailureException {
sendPush(PUSH_APN_PATH, message);
}
public List<UnregisteredEvent> getGcmFeedback() throws IOException {
return getFeedback(GCM_FEEDBACK_PATH);
}
public List<UnregisteredEvent> getApnFeedback() throws IOException {
return getFeedback(APN_FEEDBACK_PATH);
}
private void sendPush(String path, Object entity) throws TransientPushFailureException {
try {
Response response = client.target("http://" + host + ":" + port)
.path(path)
.request()
.header("Authorization", authorization)
.put(Entity.entity(entity, MediaType.APPLICATION_JSON_TYPE));
if (response.getStatus() != 204 && response.getStatus() != 200) {
logger.warn("PushServer response: " + response.getStatus() + " " + response.getStatusInfo().getReasonPhrase());
throw new TransientPushFailureException("Bad response: " + response.getStatus());
}
} catch (ProcessingException e) {
logger.warn("Push error: ", e);
throw new TransientPushFailureException(e);
}
}
private List<UnregisteredEvent> getFeedback(String path) throws IOException {
try {
UnregisteredEventList unregisteredEvents = client.target("http://" + host + ":" + port)
.path(path)
.request()
.header("Authorization", authorization)
.get(UnregisteredEventList.class);
return unregisteredEvents.getDevices();
} catch (ProcessingException e) {
logger.warn("Request error:", e);
throw new IOException(e);
}
}
private String getAuthorizationHeader(String username, String password) {
return "Basic " + Base64.encodeBytes((username + ":" + password).getBytes());
}
}

View File

@ -1,173 +0,0 @@
package org.whispersystems.textsecuregcm.workers;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.google.common.base.Optional;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.glassfish.jersey.client.ClientProperties;
import org.skife.jdbi.v2.DBI;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.entities.GcmMessage;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.PushServiceClient;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Messages;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.ws.rs.client.Client;
import java.util.List;
import java.util.concurrent.TimeUnit;
import io.dropwizard.Application;
import io.dropwizard.cli.EnvironmentCommand;
import io.dropwizard.client.JerseyClientBuilder;
import io.dropwizard.jdbi.DBIFactory;
import io.dropwizard.setup.Environment;
import redis.clients.jedis.JedisPool;
public class PushCommand extends EnvironmentCommand<WhisperServerConfiguration> {
private final Logger logger = LoggerFactory.getLogger(DirectoryCommand.class);
private static final int LIMIT = 1000;
public PushCommand() {
super(new Application<WhisperServerConfiguration>() {
@Override
public void run(WhisperServerConfiguration configuration, Environment environment)
throws Exception
{
}
}, "push", "send pushes");
}
@Override
public void configure(Subparser subparser) {
super.configure(subparser);
subparser.addArgument("-t", "--time")
.dest("timestamp")
.type(Long.class)
.required(true)
.help("The starting timestamp to notify users from");
subparser.addArgument("-o", "--offset")
.dest("offset")
.type(Integer.class)
.required(true)
.help("The starting offset in the user query");
}
@Override
protected void run(Environment environment, Namespace namespace,
WhisperServerConfiguration configuration)
throws Exception
{
try {
long timestampStart = namespace.getLong("timestamp");
int offset = namespace.getInt("offset");
environment.getObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
DBIFactory dbiFactory = new DBIFactory();
DBI database = dbiFactory.build(environment, configuration.getDataSourceFactory(), "accountdb" );
DBI messagedb = dbiFactory.build(environment, configuration.getMessageStoreConfiguration(), "messagedb");
Accounts accounts = database.onDemand(Accounts.class);
Messages messages = messagedb.onDemand(Messages.class);
JedisPool cacheClient = new RedisClientFactory(configuration.getCacheConfiguration().getUrl()).getRedisClientPool();
JedisPool redisClient = new RedisClientFactory(configuration.getDirectoryConfiguration().getUrl()).getRedisClientPool();
DirectoryManager directory = new DirectoryManager(redisClient);
AccountsManager accountsManager = new AccountsManager(accounts, directory, cacheClient);
Client httpClient = initializeHttpClient(environment, configuration);
PushServiceClient pushServiceClient = new PushServiceClient(httpClient, configuration.getPushConfiguration());
while (true) {
List<Pair<String, Integer>> pendingDestinations = messages.getPendingDestinations(timestampStart, offset, LIMIT);
if (pendingDestinations == null || pendingDestinations.size() == 0) {
break;
}
for (Pair<String, Integer> pendingDestination : pendingDestinations) {
Optional<Account> account = accountsManager.get(pendingDestination.first());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(pendingDestination.second());
if (device.isPresent()) {
if (device.get().getGcmId() != null) {
sendGcm(pushServiceClient, account.get(), device.get());
} else if (device.get().getApnId() != null) {
sendApn(pushServiceClient, account.get(), device.get());
}
} else {
logger.warn("No device found: " + pendingDestination.first() + ", " + pendingDestination.second());
}
} else {
logger.warn("No account found: " + pendingDestination.first());
}
}
logger.warn("Processed " + LIMIT + "...");
offset += LIMIT;
}
logger.warn("Finished!");
} catch (Exception ex) {
logger.warn("Exception", ex);
}
}
private void sendGcm(PushServiceClient pushServiceClient, Account account, Device device) {
try {
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), "", false, true);
logger.warn("Sending GCM: " + account.getNumber());
pushServiceClient.send(gcmMessage);
} catch (TransientPushFailureException e) {
logger.warn("Push failure", e);
}
}
private void sendApn(PushServiceClient pushServiceClient, Account account, Device device) {
if (!Util.isEmpty(device.getVoipApnId())) {
try {
ApnMessage apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(),
String.format(PushSender.APN_PAYLOAD, 1),
true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION));
logger.warn("Sending APN: " + account.getNumber());
pushServiceClient.send(apnMessage);
} catch (TransientPushFailureException e) {
logger.warn("SILENT PUSH LOSS", e);
}
}
}
private Client initializeHttpClient(Environment environment, WhisperServerConfiguration config) {
Client httpClient = new JerseyClientBuilder(environment).using(config.getJerseyClientConfiguration())
.build(getName());
httpClient.property(ClientProperties.CONNECT_TIMEOUT, 1000);
httpClient.property(ClientProperties.READ_TIMEOUT, 1000);
return httpClient;
}
}

View File

@ -0,0 +1,87 @@
package org.whispersystems.textsecuregcm.tests.push;
import com.google.common.base.Optional;
import com.notnoop.apns.ApnsService;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Date;
import java.util.HashMap;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.*;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
public class APNSenderTest {
private static final String DESTINATION_NUMBER = "+14151231234";
private static final String DESTINATION_APN_ID = "foo";
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final JedisPool jedisPool = mock(JedisPool.class);
private final Jedis jedis = mock(Jedis.class);
private final ApnsService voipService = mock(ApnsService.class);
private final ApnsService apnsService = mock(ApnsService.class);
private final Account destinationAccount = mock(Account.class);
private final Device destinationDevice = mock(Device.class );
@Before
public void setup() {
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID);
when(accountsManager.get(DESTINATION_NUMBER)).thenReturn(Optional.of(destinationAccount));
when(jedisPool.getResource()).thenReturn(jedis);
when(jedis.get("APN-" + DESTINATION_APN_ID)).thenReturn(DESTINATION_NUMBER + "." + 1);
when(voipService.getInactiveDevices()).thenReturn(new HashMap<String, Date>() {{
put(DESTINATION_APN_ID, new Date(System.currentTimeMillis()));
}});
when(apnsService.getInactiveDevices()).thenReturn(new HashMap<String, Date>());
}
@Test
public void testSendVoip() throws TransientPushFailureException {
APNSender apnSender = new APNSender(accountsManager, jedisPool, apnsService, voipService, false, false);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30);
apnSender.sendMessage(message);
verify(jedis, times(1)).set(eq("APN-" + DESTINATION_APN_ID.toLowerCase()), eq(DESTINATION_NUMBER + "." + 1));
verify(voipService, times(1)).push(eq(DESTINATION_APN_ID), eq(message.getMessage()), eq(new Date(30)));
verifyNoMoreInteractions(apnsService);
}
@Test
public void testSendApns() throws TransientPushFailureException {
APNSender apnSender = new APNSender(accountsManager, jedisPool, apnsService, voipService, false, false);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", false, 30);
apnSender.sendMessage(message);
verify(jedis, times(1)).set(eq("APN-" + DESTINATION_APN_ID.toLowerCase()), eq(DESTINATION_NUMBER + "." + 1));
verify(apnsService, times(1)).push(eq(DESTINATION_APN_ID), eq(message.getMessage()), eq(new Date(30)));
verifyNoMoreInteractions(voipService);
}
@Test
public void testFeedbackUnregistered() {
APNSender apnSender = new APNSender(accountsManager, jedisPool, apnsService, voipService, false, false);
apnSender.checkFeedback();
verify(jedis, times(1)).get(eq("APN-" +DESTINATION_APN_ID));
verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER));
verify(destinationDevice, times(1)).setApnId(eq((String)null));
verify(accountsManager, times(1)).update(eq(destinationAccount));
}
}

View File

@ -2,10 +2,10 @@ package org.whispersystems.textsecuregcm.tests.push;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.PushServiceClient;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Util;
@ -14,23 +14,21 @@ import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
public class ApnFallbackManagerTest {
@Test
public void testFullFallback() throws Exception {
PushServiceClient pushServiceClient = mock(PushServiceClient.class);
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222223", 1L);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111);
ApnFallbackTask task = new ApnFallbackTask("foo", "voipfoo", message, 500, 0);
APNSender apnSender = mock(APNSender.class );
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222223", 1L);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111);
ApnFallbackTask task = new ApnFallbackTask("foo", "voipfoo", message, 500, 0);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(apnSender, pubSubManager);
apnFallbackManager.start();
apnFallbackManager.schedule(address, task);
@ -38,7 +36,7 @@ public class ApnFallbackManagerTest {
Util.sleep(1100);
ArgumentCaptor<ApnMessage> captor = ArgumentCaptor.forClass(ApnMessage.class);
verify(pushServiceClient, times(2)).send(captor.capture());
verify(apnSender, times(2)).sendMessage(captor.capture());
verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager));
List<ApnMessage> arguments = captor.getAllValues();
@ -56,7 +54,7 @@ public class ApnFallbackManagerTest {
@Test
public void testNoFallback() throws Exception {
PushServiceClient pushServiceClient = mock(PushServiceClient.class);
APNSender pushServiceClient = mock(APNSender.class);
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222222", 1);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);

View File

@ -0,0 +1,128 @@
package org.whispersystems.textsecuregcm.tests.push;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.SettableFuture;
import org.junit.Test;
import org.mockito.Matchers;
import org.whispersystems.gcm.server.Message;
import org.whispersystems.gcm.server.Result;
import org.whispersystems.gcm.server.Sender;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
public class GCMSenderTest {
@Test
public void testSendMessage() {
AccountsManager accountsManager = mock(AccountsManager.class);
Sender sender = mock(Sender.class );
Result successResult = mock(Result.class );
SynchronousExecutorService executorService = new SynchronousExecutorService();
when(successResult.isInvalidRegistrationId()).thenReturn(false);
when(successResult.isUnregistered()).thenReturn(false);
when(successResult.hasCanonicalRegistrationId()).thenReturn(false);
when(successResult.isSuccess()).thenReturn(true);
GcmMessage message = new GcmMessage("foo", "+12223334444", 1, false);
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
SettableFuture<Result> successFuture = SettableFuture.create();
successFuture.set(successResult);
when(sender.send(any(Message.class), Matchers.anyObject())).thenReturn(successFuture);
when(successResult.getContext()).thenReturn(message);
gcmSender.sendMessage(message);
verify(sender, times(1)).send(any(Message.class), eq(message));
}
@Test
public void testSendError() {
String destinationNumber = "+12223334444";
String gcmId = "foo";
AccountsManager accountsManager = mock(AccountsManager.class);
Sender sender = mock(Sender.class );
Result invalidResult = mock(Result.class );
SynchronousExecutorService executorService = new SynchronousExecutorService();
Account destinationAccount = mock(Account.class);
Device destinationDevice = mock(Device.class );
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
when(invalidResult.isInvalidRegistrationId()).thenReturn(true);
when(invalidResult.isUnregistered()).thenReturn(false);
when(invalidResult.hasCanonicalRegistrationId()).thenReturn(false);
when(invalidResult.isSuccess()).thenReturn(true);
GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, false);
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
SettableFuture<Result> invalidFuture = SettableFuture.create();
invalidFuture.set(invalidResult);
when(sender.send(any(Message.class), Matchers.anyObject())).thenReturn(invalidFuture);
when(invalidResult.getContext()).thenReturn(message);
gcmSender.sendMessage(message);
verify(sender, times(1)).send(any(Message.class), eq(message));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(destinationDevice, times(1)).setGcmId(eq((String)null));
}
@Test
public void testCanonicalId() {
String destinationNumber = "+12223334444";
String gcmId = "foo";
String canonicalId = "bar";
AccountsManager accountsManager = mock(AccountsManager.class);
Sender sender = mock(Sender.class );
Result canonicalResult = mock(Result.class );
SynchronousExecutorService executorService = new SynchronousExecutorService();
Account destinationAccount = mock(Account.class);
Device destinationDevice = mock(Device.class );
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
when(canonicalResult.isInvalidRegistrationId()).thenReturn(false);
when(canonicalResult.isUnregistered()).thenReturn(false);
when(canonicalResult.hasCanonicalRegistrationId()).thenReturn(true);
when(canonicalResult.isSuccess()).thenReturn(false);
when(canonicalResult.getCanonicalRegistrationId()).thenReturn(canonicalId);
GcmMessage message = new GcmMessage(gcmId, destinationNumber, 1, false);
GCMSender gcmSender = new GCMSender(accountsManager, sender, executorService);
SettableFuture<Result> invalidFuture = SettableFuture.create();
invalidFuture.set(canonicalResult);
when(sender.send(any(Message.class), Matchers.anyObject())).thenReturn(invalidFuture);
when(canonicalResult.getContext()).thenReturn(message);
gcmSender.sendMessage(message);
verify(sender, times(1)).send(any(Message.class), eq(message));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(destinationDevice, times(1)).setGcmId(eq(canonicalId));
}
}

View File

@ -0,0 +1,111 @@
package org.whispersystems.textsecuregcm.tests.util;
import com.google.common.util.concurrent.SettableFuture;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class SynchronousExecutorService implements ExecutorService {
private boolean shutdown = false;
@Override
public void shutdown() {
shutdown = true;
}
@Override
public List<Runnable> shutdownNow() {
shutdown = true;
return Collections.emptyList();
}
@Override
public boolean isShutdown() {
return shutdown;
}
@Override
public boolean isTerminated() {
return shutdown;
}
@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return true;
}
@Override
public <T> Future<T> submit(Callable<T> task) {
SettableFuture<T> future = null;
try {
future = SettableFuture.create();
future.set(task.call());
} catch (Throwable e) {
future.setException(e);
}
return future;
}
@Override
public <T> Future<T> submit(Runnable task, T result) {
SettableFuture<T> future = SettableFuture.create();
task.run();
future.set(result);
return future;
}
@Override
public Future<?> submit(Runnable task) {
SettableFuture future = SettableFuture.create();
task.run();
future.set(null);
return future;
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
List<Future<T>> results = new LinkedList<>();
for (Callable<T> callable : tasks) {
SettableFuture<T> future = SettableFuture.create();
try {
future.set(callable.call());
} catch (Throwable e) {
future.setException(e);
}
results.add(future);
}
return results;
}
@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException {
return invokeAll(tasks);
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
return null;
}
@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
return null;
}
@Override
public void execute(Runnable command) {
command.run();
}
}