Support for vpush only retries

This commit is contained in:
Moxie Marlinspike 2018-05-17 12:39:03 -07:00
parent 6652f96349
commit e26e383bd7
25 changed files with 559 additions and 648 deletions

View File

@ -76,6 +76,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty @JsonProperty
private RedisConfiguration directory; private RedisConfiguration directory;
@NotNull
@Valid
@JsonProperty
private RedisConfiguration pushScheduler;
@NotNull @NotNull
@Valid @Valid
@JsonProperty @JsonProperty
@ -170,6 +175,10 @@ public class WhisperServerConfiguration extends Configuration {
return messageCache; return messageCache;
} }
public RedisConfiguration getPushScheduler() {
return pushScheduler;
}
public DataSourceFactory getMessageStoreConfiguration() { public DataSourceFactory getMessageStoreConfiguration() {
return messageStore; return messageStore;
} }

View File

@ -161,9 +161,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl(), config.getCacheConfiguration().getReplicaUrls() ); RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl(), config.getCacheConfiguration().getReplicaUrls() );
RedisClientFactory directoryClientFactory = new RedisClientFactory(config.getDirectoryConfiguration().getUrl(), config.getDirectoryConfiguration().getReplicaUrls() ); RedisClientFactory directoryClientFactory = new RedisClientFactory(config.getDirectoryConfiguration().getUrl(), config.getDirectoryConfiguration().getReplicaUrls() );
RedisClientFactory messagesClientFactory = new RedisClientFactory(config.getMessageCacheConfiguration().getRedisConfiguration().getUrl(), config.getMessageCacheConfiguration().getRedisConfiguration().getReplicaUrls()); RedisClientFactory messagesClientFactory = new RedisClientFactory(config.getMessageCacheConfiguration().getRedisConfiguration().getUrl(), config.getMessageCacheConfiguration().getRedisConfiguration().getReplicaUrls());
RedisClientFactory pushSchedulerClientFactory = new RedisClientFactory(config.getPushScheduler().getUrl(), config.getPushScheduler().getReplicaUrls() );
ReplicatedJedisPool cacheClient = cacheClientFactory.getRedisClientPool(); ReplicatedJedisPool cacheClient = cacheClientFactory.getRedisClientPool();
ReplicatedJedisPool directoryClient = directoryClientFactory.getRedisClientPool(); ReplicatedJedisPool directoryClient = directoryClientFactory.getRedisClientPool();
ReplicatedJedisPool messagesClient = messagesClientFactory.getRedisClientPool(); ReplicatedJedisPool messagesClient = messagesClientFactory.getRedisClientPool();
ReplicatedJedisPool pushSchedulerClient = pushSchedulerClientFactory.getRedisClientPool();
DirectoryManager directory = new DirectoryManager(directoryClient); DirectoryManager directory = new DirectoryManager(directoryClient);
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheClient); PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheClient);
@ -182,7 +185,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FederatedPeerAuthenticator federatedPeerAuthenticator = new FederatedPeerAuthenticator(config.getFederationConfiguration()); FederatedPeerAuthenticator federatedPeerAuthenticator = new FederatedPeerAuthenticator(config.getFederationConfiguration());
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), cacheClient); RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), cacheClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(apnSender, pubSubManager); ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushSchedulerClient, apnSender, accountsManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration()); TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender); SmsSender smsSender = new SmsSender(twilioSmsSender);
UrlSigner urlSigner = new UrlSigner(config.getAttachmentsConfiguration()); UrlSigner urlSigner = new UrlSigner(config.getAttachmentsConfiguration());
@ -200,7 +203,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner); AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager); KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager); MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager, apnFallbackManager);
ProfileController profileController = new ProfileController(rateLimiters , accountsManager, config.getProfilesConfiguration()); ProfileController profileController = new ProfileController(rateLimiters , accountsManager, config.getProfilesConfiguration());
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<Account>() environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<Account>()
@ -227,7 +230,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
/// ///
WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment(environment, config.getWebSocketConfiguration(), 90000); WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment(environment, config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(deviceAuthenticator)); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(deviceAuthenticator));
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, messagesManager, pubSubManager)); webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(pushSender, receiptSender, messagesManager, pubSubManager, apnFallbackManager));
webSocketEnvironment.jersey().register(new KeepAliveController(pubSubManager)); webSocketEnvironment.jersey().register(new KeepAliveController(pubSubManager));
webSocketEnvironment.jersey().register(messageController); webSocketEnvironment.jersey().register(messageController);
webSocketEnvironment.jersey().register(profileController); webSocketEnvironment.jersey().register(profileController);

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2013 Open WhisperSystems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -33,10 +33,12 @@ import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException; import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -75,13 +77,15 @@ public class MessageController {
private final FederatedClientManager federatedClientManager; private final FederatedClientManager federatedClientManager;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final ApnFallbackManager apnFallbackManager;
public MessageController(RateLimiters rateLimiters, public MessageController(RateLimiters rateLimiters,
PushSender pushSender, PushSender pushSender,
ReceiptSender receiptSender, ReceiptSender receiptSender,
AccountsManager accountsManager, AccountsManager accountsManager,
MessagesManager messagesManager, MessagesManager messagesManager,
FederatedClientManager federatedClientManager) FederatedClientManager federatedClientManager,
ApnFallbackManager apnFallbackManager)
{ {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.pushSender = pushSender; this.pushSender = pushSender;
@ -89,6 +93,7 @@ public class MessageController {
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.federatedClientManager = federatedClientManager; this.federatedClientManager = federatedClientManager;
this.apnFallbackManager = apnFallbackManager;
} }
@Timed @Timed
@ -134,6 +139,12 @@ public class MessageController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth Account account) { public OutgoingMessageEntityList getPendingMessages(@Auth Account account) {
assert account.getAuthenticatedDevice().isPresent();
if (!Util.isEmpty(account.getAuthenticatedDevice().get().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get()));
}
return messagesManager.getMessagesForDevice(account.getNumber(), return messagesManager.getMessagesForDevice(account.getNumber(),
account.getAuthenticatedDevice().get().getId()); account.getAuthenticatedDevice().get().getId());
} }
@ -219,7 +230,7 @@ public class MessageController {
messageBuilder.setRelay(source.getRelay().get()); messageBuilder.setRelay(source.getRelay().get());
} }
pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), incomingMessage.isSilent()); pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build());
} catch (NotPushRegisteredException e) { } catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) throw new NoSuchUserException(e); if (destinationDevice.isMaster()) throw new NoSuchUserException(e);
else logger.debug("Not registered", e); else logger.debug("Not registered", e);

View File

@ -17,7 +17,6 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
public class IncomingMessage { public class IncomingMessage {
@ -45,10 +44,6 @@ public class IncomingMessage {
@JsonProperty @JsonProperty
private long timestamp; // deprecated private long timestamp; // deprecated
@JsonProperty
private boolean silent = false;
public String getDestination() { public String getDestination() {
return destination; return destination;
} }
@ -77,7 +72,4 @@ public class IncomingMessage {
return content; return content;
} }
public boolean isSilent() {
return silent;
}
} }

View File

@ -1,4 +1,4 @@
/** /*
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2013 Open WhisperSystems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
@ -16,6 +16,9 @@
*/ */
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.FutureCallback;
@ -25,10 +28,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult; import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.textsecuregcm.util.Constants;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.IOException; import java.io.IOException;
@ -37,12 +41,17 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
public class APNSender implements Managed { public class APNSender implements Managed {
private final Logger logger = LoggerFactory.getLogger(APNSender.class); private final Logger logger = LoggerFactory.getLogger(APNSender.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter unregisteredEventStale = metricRegistry.meter(name(APNSender.class, "unregistered_event_stale"));
private static final Meter unregisteredEventFresh = metricRegistry.meter(name(APNSender.class, "unregistered_event_fresh"));
private ExecutorService executor; private ExecutorService executor;
private ApnFallbackManager fallbackManager; private ApnFallbackManager fallbackManager;
@ -71,9 +80,7 @@ public class APNSender implements Managed {
this.bundleId = bundleId; this.bundleId = bundleId;
} }
public ListenableFuture<ApnResult> sendMessage(final ApnMessage message) public ListenableFuture<ApnResult> sendMessage(final ApnMessage message) {
throws TransientPushFailureException
{
String topic = bundleId; String topic = bundleId;
if (message.isVoip()) { if (message.isVoip()) {
@ -106,13 +113,13 @@ public class APNSender implements Managed {
} }
@Override @Override
public void start() throws Exception { public void start() {
this.executor = Executors.newSingleThreadExecutor(); this.executor = Executors.newSingleThreadExecutor();
this.apnsClient.connect(sandbox); this.apnsClient.connect(sandbox);
} }
@Override @Override
public void stop() throws Exception { public void stop() {
this.executor.shutdown(); this.executor.shutdown();
this.apnsClient.disconnect(); this.apnsClient.disconnect();
} }
@ -121,13 +128,14 @@ public class APNSender implements Managed {
this.fallbackManager = fallbackManager; this.fallbackManager = fallbackManager;
} }
private void handleUnregisteredUser(String registrationId, String number, int deviceId) { private void handleUnregisteredUser(String registrationId, String number, long deviceId) {
logger.info("Got APN Unregistered: " + number + "," + deviceId); // logger.info("Got APN Unregistered: " + number + "," + deviceId);
Optional<Account> account = accountsManager.get(number); Optional<Account> account = accountsManager.get(number);
if (!account.isPresent()) { if (!account.isPresent()) {
logger.info("No account found: " + number); logger.info("No account found: " + number);
unregisteredEventStale.mark();
return; return;
} }
@ -135,6 +143,7 @@ public class APNSender implements Managed {
if (!device.isPresent()) { if (!device.isPresent()) {
logger.info("No device found: " + number); logger.info("No device found: " + number);
unregisteredEventStale.mark();
return; return;
} }
@ -142,24 +151,26 @@ public class APNSender implements Managed {
!registrationId.equals(device.get().getVoipApnId())) !registrationId.equals(device.get().getVoipApnId()))
{ {
logger.info("Registration ID does not match: " + registrationId + ", " + device.get().getApnId() + ", " + device.get().getVoipApnId()); logger.info("Registration ID does not match: " + registrationId + ", " + device.get().getApnId() + ", " + device.get().getVoipApnId());
unregisteredEventStale.mark();
return; return;
} }
if (registrationId.equals(device.get().getApnId())) { // if (registrationId.equals(device.get().getApnId())) {
logger.info("APN Unregister APN ID matches! " + number + ", " + deviceId); // logger.info("APN Unregister APN ID matches! " + number + ", " + deviceId);
} else if (registrationId.equals(device.get().getVoipApnId())) { // } else if (registrationId.equals(device.get().getVoipApnId())) {
logger.info("APN Unregister VoIP ID matches! " + number + ", " + deviceId); // logger.info("APN Unregister VoIP ID matches! " + number + ", " + deviceId);
} // }
long tokenTimestamp = device.get().getPushTimestamp(); long tokenTimestamp = device.get().getPushTimestamp();
if (tokenTimestamp != 0 && System.currentTimeMillis() < tokenTimestamp + TimeUnit.SECONDS.toMillis(10)) if (tokenTimestamp != 0 && System.currentTimeMillis() < tokenTimestamp + TimeUnit.SECONDS.toMillis(10))
{ {
logger.info("APN Unregister push timestamp is more recent: " + tokenTimestamp + ", " + number); logger.info("APN Unregister push timestamp is more recent: " + tokenTimestamp + ", " + number);
unregisteredEventStale.mark();
return; return;
} }
logger.info("APN Unregister timestamp matches: " + device.get().getApnId() + ", " + device.get().getVoipApnId()); // logger.info("APN Unregister timestamp matches: " + device.get().getApnId() + ", " + device.get().getVoipApnId());
// device.get().setApnId(null); // device.get().setApnId(null);
// device.get().setVoipApnId(null); // device.get().setVoipApnId(null);
// device.get().setFetchesMessages(false); // device.get().setFetchesMessages(false);
@ -168,5 +179,10 @@ public class APNSender implements Managed {
// if (fallbackManager != null) { // if (fallbackManager != null) {
// fallbackManager.cancel(new WebsocketAddress(number, deviceId)); // fallbackManager.cancel(new WebsocketAddress(number, deviceId));
// } // }
if (fallbackManager != null) {
RedisOperation.unchecked(() -> fallbackManager.cancel(account.get(), device.get()));
unregisteredEventFresh.mark();
}
} }
} }

View File

@ -1,237 +1,247 @@
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.RatioGauge; import com.codahale.metrics.RatioGauge;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.redis.RedisException;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Iterator; import java.io.IOException;
import java.util.LinkedHashMap; import java.util.Arrays;
import java.util.Map.Entry; import java.util.Collections;
import java.util.concurrent.TimeUnit; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.exceptions.JedisException;
public class ApnFallbackManager implements Managed, Runnable, DispatchChannel { @SuppressWarnings("Guava")
public class ApnFallbackManager implements Managed, Runnable {
private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class); private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class);
public static final int FALLBACK_DURATION = 15; private static final String PENDING_NOTIFICATIONS_KEY = "PENDING_APN";
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter voipOneSuccess = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_success")); private static final Meter delivered = metricRegistry.meter(name(ApnFallbackManager.class, "voip_delivered"));
private static final Meter voipOneDelivery = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_failure")); private static final Meter sent = metricRegistry.meter(name(ApnFallbackManager.class, "voip_sent" ));
private static final Histogram voipOneSuccessHistogram = metricRegistry.histogram(name(ApnFallbackManager.class, "voip_one_success_histogram")); private static final Meter retry = metricRegistry.meter(name(ApnFallbackManager.class, "voip_retry"));
static { static {
metricRegistry.register(name(ApnFallbackManager.class, "voip_one_success_ratio"), new VoipRatioGauge(voipOneSuccess, voipOneDelivery)); metricRegistry.register(name(ApnFallbackManager.class, "voip_ratio"), new VoipRatioGauge(delivered, sent));
} }
private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
private final APNSender apnSender; private final APNSender apnSender;
private final PubSubManager pubSubManager; private final AccountsManager accountsManager;
public ApnFallbackManager(APNSender apnSender, PubSubManager pubSubManager) { private final ReplicatedJedisPool jedisPool;
private final InsertOperation insertOperation;
private final GetOperation getOperation;
private final RemoveOperation removeOperation;
private AtomicBoolean running = new AtomicBoolean(false);
private boolean finished;
public ApnFallbackManager(ReplicatedJedisPool jedisPool,
APNSender apnSender,
AccountsManager accountsManager)
throws IOException
{
this.apnSender = apnSender; this.apnSender = apnSender;
this.pubSubManager = pubSubManager; this.accountsManager = accountsManager;
this.jedisPool = jedisPool;
this.insertOperation = new InsertOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
} }
public void schedule(final WebsocketAddress address, ApnFallbackTask task) { public void schedule(Account account, Device device) throws RedisException {
voipOneDelivery.mark(); try {
sent.mark();
if (taskQueue.put(address, task)) { insertOperation.insert(account, device, System.currentTimeMillis() + (15 * 1000), (15 * 1000));
pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); } catch (JedisException e) {
throw new RedisException(e);
} }
} }
private void scheduleRetry(final WebsocketAddress address, ApnFallbackTask task) { public boolean isScheduled(Account account, Device device) throws RedisException {
if (taskQueue.putIfMissing(address, task)) { try {
pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
try (Jedis jedis = jedisPool.getReadResource()) {
return jedis.zscore(PENDING_NOTIFICATIONS_KEY, endpoint) != null;
}
} catch (JedisException e) {
throw new RedisException(e);
} }
} }
public void cancel(WebsocketAddress address) { public void cancel(Account account, Device device) throws RedisException {
ApnFallbackTask task = taskQueue.remove(address); try {
if (removeOperation.remove(account, device)) {
if (task != null) { delivered.mark();
pubSubManager.unsubscribe(new WebSocketConnectionInfo(address), this); }
voipOneSuccess.mark(); } catch (JedisException e) {
voipOneSuccessHistogram.update(System.currentTimeMillis() - task.getScheduledTime()); throw new RedisException(e);
} }
} }
@Override @Override
public void start() throws Exception { public synchronized void start() {
running.set(true);
new Thread(this).start(); new Thread(this).start();
} }
@Override @Override
public void stop() throws Exception { public synchronized void stop() {
running.set(false);
while (!finished) Util.wait(this);
} }
@Override @Override
public void run() { public void run() {
while (true) { while (running.get()) {
try { try {
Entry<WebsocketAddress, ApnFallbackTask> taskEntry = taskQueue.get(); List<byte[]> pendingNotifications = getOperation.getPending(100);
ApnFallbackTask task = taskEntry.getValue();
ApnMessage message; for (byte[] pendingNotification : pendingNotifications) {
String numberAndDevice = new String(pendingNotification);
Optional<Pair<String, Long>> separated = getSeparated(numberAndDevice);
if (task.getAttempt() == 0) { if (!separated.isPresent()) {
message = new ApnMessage(task.getMessage(), task.getVoipApnId(), true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(FALLBACK_DURATION)); removeOperation.remove(numberAndDevice);
scheduleRetry(taskEntry.getKey(), new ApnFallbackTask(task.getApnId(), task.getVoipApnId(), task.getMessage(), task.getDelay(),1)); continue;
} else {
message = new ApnMessage(task.getMessage(), task.getApnId(), false, ApnMessage.MAX_EXPIRATION);
pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this);
} }
apnSender.sendMessage(message); Optional<Account> account = accountsManager.get(separated.get().first());
} catch (Throwable e) {
logger.warn("ApnFallbackThread", e); if (!account.isPresent()) {
removeOperation.remove(numberAndDevice);
continue;
} }
Optional<Device> device = account.get().getDevice(separated.get().second());
if (!device.isPresent()) {
removeOperation.remove(numberAndDevice);
continue;
}
String apnId = device.get().getVoipApnId();
if (apnId == null) {
removeOperation.remove(account.get(), device.get());
continue;
}
apnSender.sendMessage(new ApnMessage(apnId, separated.get().first(), separated.get().second(), true));
retry.mark();
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
Util.sleep(1000);
}
synchronized (ApnFallbackManager.this) {
finished = true;
notifyAll();
} }
} }
@Override private Optional<Pair<String, Long>> getSeparated(String encoded) {
public void onDispatchMessage(String channel, byte[] message) {
try { try {
PubSubMessage notification = PubSubMessage.parseFrom(message); if (encoded == null) return Optional.absent();
if (notification.getType().getNumber() == PubSubMessage.Type.CONNECTED_VALUE) { String[] parts = encoded.split(":");
WebSocketConnectionInfo address = new WebSocketConnectionInfo(channel);
cancel(address.getWebsocketAddress()); if (parts.length != 2) {
} else { logger.warn("Got strange encoded number: " + encoded);
logger.warn("Got strange pubsub type: " + notification.getType().getNumber()); return Optional.absent();
} }
} catch (WebSocketConnectionInfo.FormattingException e) { return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1])));
logger.warn("Bad formatting?", e); } catch (NumberFormatException e) {
} catch (InvalidProtocolBufferException e) { logger.warn("Badly formatted: " + encoded, e);
logger.warn("Bad protobuf", e); return Optional.absent();
} }
} }
@Override private static class RemoveOperation {
public void onDispatchSubscribed(String channel) {}
@Override private final LuaScript luaScript;
public void onDispatchUnsubscribed(String channel) {}
public static class ApnFallbackTask { RemoveOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/remove.lua");
private final long delay;
private final long scheduledTime;
private final String apnId;
private final String voipApnId;
private final ApnMessage message;
private final int attempt;
public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message) {
this(apnId, voipApnId, message, TimeUnit.SECONDS.toMillis(FALLBACK_DURATION), 0);
} }
@VisibleForTesting boolean remove(Account account, Device device) {
public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message, long delay, int attempt) { String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
this.scheduledTime = System.currentTimeMillis(); return remove(endpoint);
this.delay = delay;
this.apnId = apnId;
this.voipApnId = voipApnId;
this.message = message;
this.attempt = attempt;
} }
public String getApnId() { boolean remove(String endpoint) {
return apnId; if (!PENDING_NOTIFICATIONS_KEY.equals(endpoint)) {
List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes());
List<byte[]> args = Collections.emptyList();
return ((long)luaScript.execute(keys, args)) > 0;
} }
public String getVoipApnId() { return false;
return voipApnId;
} }
public ApnMessage getMessage() {
return message;
} }
public long getScheduledTime() { private static class GetOperation {
return scheduledTime;
private final LuaScript luaScript;
GetOperation(ReplicatedJedisPool jedisPool) throws IOException {
this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/get.lua");
} }
public long getExecutionTime() { @SuppressWarnings("SameParameterValue")
return scheduledTime + delay; List<byte[]> getPending(int limit) {
} List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes());
List<byte[]> args = Arrays.asList(String.valueOf(System.currentTimeMillis()).getBytes(), String.valueOf(limit).getBytes());
public long getDelay() { return (List<byte[]>) luaScript.execute(keys, args);
return delay;
}
public int getAttempt() {
return attempt;
} }
} }
@VisibleForTesting private static class InsertOperation {
public static class ApnFallbackTaskQueue {
private final LinkedHashMap<WebsocketAddress, ApnFallbackTask> tasks = new LinkedHashMap<>(); private final LuaScript luaScript;
public Entry<WebsocketAddress, ApnFallbackTask> get() { InsertOperation(ReplicatedJedisPool jedisPool) throws IOException {
while (true) { this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/insert.lua");
long timeDelta;
synchronized (tasks) {
while (tasks.isEmpty()) Util.wait(tasks);
Iterator<Entry<WebsocketAddress, ApnFallbackTask>> iterator = tasks.entrySet().iterator();
Entry<WebsocketAddress, ApnFallbackTask> nextTask = iterator.next();
timeDelta = nextTask.getValue().getExecutionTime() - System.currentTimeMillis();
if (timeDelta <= 0) {
iterator.remove();
return nextTask;
}
} }
Util.sleep(timeDelta); public void insert(Account account, Device device, long timestamp, long interval) {
} String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId();
}
public boolean put(WebsocketAddress address, ApnFallbackTask task) { List<byte[]> keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes());
synchronized (tasks) { List<byte[]> args = Arrays.asList(String.valueOf(timestamp).getBytes(), String.valueOf(interval).getBytes(),
ApnFallbackTask previous = tasks.put(address, task); account.getNumber().getBytes(), String.valueOf(device.getId()).getBytes());
tasks.notifyAll();
return previous == null; luaScript.execute(keys, args);
}
}
public boolean putIfMissing(WebsocketAddress address, ApnFallbackTask task) {
synchronized (tasks) {
if (tasks.containsKey(address)) return false;
return put(address, task);
}
}
public ApnFallbackTask remove(WebsocketAddress address) {
synchronized (tasks) {
return tasks.remove(address);
}
} }
} }
@ -247,7 +257,7 @@ public class ApnFallbackManager implements Managed, Runnable, DispatchChannel {
@Override @Override
protected Ratio getRatio() { protected Ratio getRatio() {
return Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate()); return RatioGauge.Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate());
} }
} }

View File

@ -2,31 +2,19 @@ package org.whispersystems.textsecuregcm.push;
public class ApnMessage { public class ApnMessage {
public static long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L; public static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}";
public static final long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L;
private final String apnId; private final String apnId;
private final String number; private final String number;
private final int deviceId; private final long deviceId;
private final String message;
private final boolean isVoip; private final boolean isVoip;
private final long expirationTime;
public ApnMessage(String apnId, String number, int deviceId, String message, boolean isVoip, long expirationTime) { public ApnMessage(String apnId, String number, long deviceId, boolean isVoip) {
this.apnId = apnId; this.apnId = apnId;
this.number = number; this.number = number;
this.deviceId = deviceId; this.deviceId = deviceId;
this.message = message;
this.isVoip = isVoip; this.isVoip = isVoip;
this.expirationTime = expirationTime;
}
public ApnMessage(ApnMessage copy, String apnId, boolean isVoip, long expirationTime) {
this.apnId = apnId;
this.number = copy.number;
this.deviceId = copy.deviceId;
this.message = copy.message;
this.isVoip = isVoip;
this.expirationTime = expirationTime;
} }
public boolean isVoip() { public boolean isVoip() {
@ -38,18 +26,18 @@ public class ApnMessage {
} }
public String getMessage() { public String getMessage() {
return message; return APN_PAYLOAD;
} }
public long getExpirationTime() { public long getExpirationTime() {
return expirationTime; return MAX_EXPIRATION;
} }
public String getNumber() { public String getNumber() {
return number; return number;
} }
public int getDeviceId() { public long getDeviceId() {
return deviceId; return deviceId;
} }
} }

View File

@ -20,14 +20,13 @@ import com.codahale.metrics.Gauge;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus; import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor; import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -37,10 +36,9 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
public class PushSender implements Managed { public class PushSender implements Managed {
@SuppressWarnings("unused")
private final Logger logger = LoggerFactory.getLogger(PushSender.class); private final Logger logger = LoggerFactory.getLogger(PushSender.class);
private static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}";
private final ApnFallbackManager apnFallbackManager; private final ApnFallbackManager apnFallbackManager;
private final GCMSender gcmSender; private final GCMSender gcmSender;
private final APNSender apnSender; private final APNSender apnSender;
@ -64,7 +62,7 @@ public class PushSender implements Managed {
(Gauge<Integer>) executor::getSize); (Gauge<Integer>) executor::getSize);
} }
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean silent) public void sendMessage(final Account account, final Device device, final Envelope message)
throws NotPushRegisteredException throws NotPushRegisteredException
{ {
if (device.getGcmId() == null && device.getApnId() == null && !device.getFetchesMessages()) { if (device.getGcmId() == null && device.getApnId() == null && !device.getFetchesMessages()) {
@ -72,17 +70,17 @@ public class PushSender implements Managed {
} }
if (queueSize > 0) { if (queueSize > 0) {
executor.execute(() -> sendSynchronousMessage(account, device, message, silent)); executor.execute(() -> sendSynchronousMessage(account, device, message));
} else { } else {
sendSynchronousMessage(account, device, message, silent); sendSynchronousMessage(account, device, message);
} }
} }
public void sendQueuedNotification(Account account, Device device, boolean fallback) public void sendQueuedNotification(Account account, Device device)
throws NotPushRegisteredException, TransientPushFailureException throws NotPushRegisteredException
{ {
if (device.getGcmId() != null) sendGcmNotification(account, device); if (device.getGcmId() != null) sendGcmNotification(account, device);
else if (device.getApnId() != null) sendApnNotification(account, device, fallback); else if (device.getApnId() != null) sendApnNotification(account, device, true);
else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!"); else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!");
} }
@ -90,9 +88,9 @@ public class PushSender implements Managed {
return webSocketSender; return webSocketSender;
} }
private void sendSynchronousMessage(Account account, Device device, Envelope message, boolean silent) { private void sendSynchronousMessage(Account account, Device device, Envelope message) {
if (device.getGcmId() != null) sendGcmMessage(account, device, message); if (device.getGcmId() != null) sendGcmMessage(account, device, message);
else if (device.getApnId() != null) sendApnMessage(account, device, message, silent); else if (device.getApnId() != null) sendApnMessage(account, device, message);
else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message);
else throw new AssertionError(); else throw new AssertionError();
} }
@ -112,36 +110,29 @@ public class PushSender implements Managed {
gcmSender.sendMessage(gcmMessage); gcmSender.sendMessage(gcmMessage);
} }
private void sendApnMessage(Account account, Device device, Envelope outgoingMessage, boolean silent) { private void sendApnMessage(Account account, Device device, Envelope outgoingMessage) {
DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN); DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN);
if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) { if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) {
boolean fallback = !silent && !outgoingMessage.getSource().equals(account.getNumber()); sendApnNotification(account, device, false);
sendApnNotification(account, device, fallback);
} }
} }
private void sendApnNotification(Account account, Device device, boolean fallback) { private void sendApnNotification(Account account, Device device, boolean newOnly) {
ApnMessage apnMessage; ApnMessage apnMessage;
if (newOnly && RedisOperation.unchecked(() -> apnFallbackManager.isScheduled(account, device))) {
return;
}
if (!Util.isEmpty(device.getVoipApnId())) { if (!Util.isEmpty(device.getVoipApnId())) {
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, true, apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true);
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION)); RedisOperation.unchecked(() -> apnFallbackManager.schedule(account, device));
if (fallback) {
apnFallbackManager.schedule(new WebsocketAddress(account.getNumber(), device.getId()),
new ApnFallbackTask(device.getApnId(), device.getVoipApnId(), apnMessage));
}
} else { } else {
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false);
false, ApnMessage.MAX_EXPIRATION);
} }
try {
apnSender.sendMessage(apnMessage); apnSender.sendMessage(apnMessage);
} catch (TransientPushFailureException e) {
logger.warn("SILENT PUSH LOSS", e);
}
} }
private void sendWebSocketMessage(Account account, Device device, Envelope outgoingMessage) private void sendWebSocketMessage(Account account, Device device, Envelope outgoingMessage)
@ -163,4 +154,5 @@ public class PushSender implements Managed {
apnSender.stop(); apnSender.stop();
gcmSender.stop(); gcmSender.stop();
} }
} }

View File

@ -72,7 +72,7 @@ public class ReceiptSender {
} }
for (Device destinationDevice : destinationDevices) { for (Device destinationDevice : destinationDevices) {
pushSender.sendMessage(destinationAccount, destinationDevice, message.build(), true); pushSender.sendMessage(destinationAccount, destinationDevice, message.build());
} }
} }

View File

@ -0,0 +1,8 @@
package org.whispersystems.textsecuregcm.redis;
public class RedisException extends Exception {
public RedisException(Exception e) {
super(e);
}
}

View File

@ -0,0 +1,37 @@
package org.whispersystems.textsecuregcm.redis;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
public class RedisOperation {
private static final Logger logger = LoggerFactory.getLogger(RedisOperation.class);
public static void unchecked(Operation operation) {
try {
operation.run();
} catch (RedisException e) {
logger.warn("Jedis failure", e);
}
}
public static boolean unchecked(BooleanOperation operation) {
try {
return operation.run();
} catch (RedisException e) {
logger.warn("Jedis failure", e);
}
return false;
}
@FunctionalInterface
public interface Operation {
public void run() throws RedisException;
}
public interface BooleanOperation {
public boolean run() throws RedisException;
}
}

View File

@ -12,7 +12,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.redis.LuaScript; import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
@ -483,11 +482,9 @@ public class MessagesCache implements Managed {
if (device.isPresent()) { if (device.isPresent()) {
try { try {
pushSender.sendQueuedNotification(account.get(), device.get(), false); pushSender.sendQueuedNotification(account.get(), device.get());
} catch (NotPushRegisteredException e) { } catch (NotPushRegisteredException e) {
logger.warn("After message persistence, no longer push registered!"); logger.warn("After message persistence, no longer push registered!");
} catch (TransientPushFailureException e) {
logger.warn("Transient push failure!", e);
} }
} }
} }

View File

@ -6,16 +6,16 @@ import com.codahale.metrics.Timer;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener; import org.whispersystems.websocket.setup.WebSocketConnectListener;
@ -29,21 +29,23 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration")); private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration"));
private final AccountsManager accountsManager;
private final PushSender pushSender; private final PushSender pushSender;
private final ReceiptSender receiptSender; private final ReceiptSender receiptSender;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final PubSubManager pubSubManager; private final PubSubManager pubSubManager;
private final ApnFallbackManager apnFallbackManager;
public AuthenticatedConnectListener(AccountsManager accountsManager, PushSender pushSender, public AuthenticatedConnectListener(PushSender pushSender,
ReceiptSender receiptSender, MessagesManager messagesManager, ReceiptSender receiptSender,
PubSubManager pubSubManager) MessagesManager messagesManager,
PubSubManager pubSubManager,
ApnFallbackManager apnFallbackManager)
{ {
this.accountsManager = accountsManager;
this.pushSender = pushSender; this.pushSender = pushSender;
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pubSubManager = pubSubManager; this.pubSubManager = pubSubManager;
this.apnFallbackManager = apnFallbackManager;
} }
@Override @Override
@ -53,7 +55,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final String connectionId = String.valueOf(new SecureRandom().nextLong()); final String connectionId = String.valueOf(new SecureRandom().nextLong());
final Timer.Context timer = durationTimer.time(); final Timer.Context timer = durationTimer.time();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device, messagesManager, account, device,
context.getClient(), connectionId); context.getClient(), connectionId);
@ -61,7 +62,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
.setContent(ByteString.copyFrom(connectionId.getBytes())) .setContent(ByteString.copyFrom(connectionId.getBytes()))
.build(); .build();
pubSubManager.publish(info, connectMessage); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device));
pubSubManager.publish(address, connectMessage); pubSubManager.publish(address, connectMessage);
pubSubManager.subscribe(address, connection); pubSubManager.subscribe(address, connection);

View File

@ -20,7 +20,6 @@ public class DeadLetterHandler implements DispatchChannel {
@Override @Override
public void onDispatchMessage(String channel, byte[] data) { public void onDispatchMessage(String channel, byte[] data) {
if (!WebSocketConnectionInfo.isType(channel)) {
try { try {
logger.info("Handling dead letter to: " + channel); logger.info("Handling dead letter to: " + channel);
@ -39,7 +38,6 @@ public class DeadLetterHandler implements DispatchChannel {
logger.warn("Invalid websocket address", e); logger.warn("Invalid websocket address", e);
} }
} }
}
@Override @Override
public void onDispatchSubscribed(String channel) { public void onDispatchSubscribed(String channel) {

View File

@ -150,11 +150,9 @@ public class WebSocketConnection implements DispatchChannel {
private void requeueMessage(Envelope message) { private void requeueMessage(Envelope message) {
pushSender.getWebSocketSender().queueMessage(account, device, message); pushSender.getWebSocketSender().queueMessage(account, device, message);
boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT;
try { try {
pushSender.sendQueuedNotification(account, device, fallback); pushSender.sendQueuedNotification(account, device);
} catch (NotPushRegisteredException | TransientPushFailureException e) { } catch (NotPushRegisteredException e) {
logger.warn("requeueMessage", e); logger.warn("requeueMessage", e);
} }
} }

View File

@ -1,62 +0,0 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.storage.PubSubAddress;
import org.whispersystems.textsecuregcm.util.Util;
public class WebSocketConnectionInfo implements PubSubAddress {
private final WebsocketAddress address;
public WebSocketConnectionInfo(WebsocketAddress address) {
this.address = address;
}
public WebSocketConnectionInfo(String serialized) throws FormattingException {
String[] parts = serialized.split("[:]", 3);
if (parts.length != 3 || !"c".equals(parts[2])) {
throw new FormattingException("Bad address: " + serialized);
}
try {
this.address = new WebsocketAddress(parts[0], Long.parseLong(parts[1]));
} catch (NumberFormatException e) {
throw new FormattingException(e);
}
}
public String serialize() {
return address.serialize() + ":c";
}
public WebsocketAddress getWebsocketAddress() {
return address;
}
public static boolean isType(String address) {
return address.endsWith(":c");
}
@Override
public boolean equals(Object other) {
return
other != null &&
other instanceof WebSocketConnectionInfo
&& ((WebSocketConnectionInfo)other).address.equals(address);
}
@Override
public int hashCode() {
return Util.hashCode(address, "c");
}
public static class FormattingException extends Exception {
public FormattingException(String message) {
super(message);
}
public FormattingException(Exception e) {
super(e);
}
}
}

View File

@ -0,0 +1,66 @@
-- keys: pending (KEYS[1])
-- argv: max_time (ARGV[1]), limit (ARGV[2])
local hgetall = function (key)
local bulk = redis.call('HGETALL', key)
local result = {}
local nextkey
for i, v in ipairs(bulk) do
if i % 2 == 1 then
nextkey = v
else
result[nextkey] = v
end
end
return result
end
local getNextInterval = function(interval)
if interval < 20000 then
return 20000
end
if interval < 40000 then
return 40000
end
if interval < 80000 then
return 80000
end
if interval < 160000 then
return 160000
end
if interval < 600000 then
return 600000
end
return 1800000
end
local results = redis.call("ZRANGEBYSCORE", KEYS[1], 0, ARGV[1], "LIMIT", 0, ARGV[2])
local collated = {}
if results and next(results) then
for i, name in ipairs(results) do
local pending = hgetall(name)
local lastInterval = pending["interval"]
if lastInterval == nil then
lastInterval = 0
end
local nextInterval = getNextInterval(tonumber(lastInterval))
redis.call("HSET", name, "interval", nextInterval)
redis.call("ZADD", KEYS[1], tonumber(ARGV[1]) + nextInterval, name)
collated[i] = pending["account"] .. ":" .. pending["device"]
end
end
return collated

View File

@ -0,0 +1,8 @@
-- keys: pending (KEYS[1]), user (KEYS[2])
-- args: timestamp (ARGV[1]), interval (ARGV[2]), account (ARGV[3]), device (ARGV[4])
redis.call("HSET", KEYS[2], "interval", ARGV[2])
redis.call("HSET", KEYS[2], "account", ARGV[3])
redis.call("HSET", KEYS[2], "device", ARGV[4])
redis.call("ZADD", KEYS[1], ARGV[1], KEYS[2])

View File

@ -0,0 +1,4 @@
-- keys: queue KEYS[1], endpoint (KEYS[2])
redis.call("DEL", KEYS[2])
return redis.call("ZREM", KEYS[1], KEYS[2])

View File

@ -20,6 +20,7 @@ import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -56,13 +57,14 @@ public class FederatedControllerTest {
private MessagesManager messagesManager = mock(MessagesManager.class); private MessagesManager messagesManager = mock(MessagesManager.class);
private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class ); private RateLimiter rateLimiter = mock(RateLimiter.class );
private ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private final SignedPreKey signedPreKey = new SignedPreKey(3333, "foo", "baar"); private final SignedPreKey signedPreKey = new SignedPreKey(3333, "foo", "baar");
private final PreKeyResponse preKeyResponseV2 = new PreKeyResponse("foo", new LinkedList<PreKeyResponseItem>()); private final PreKeyResponse preKeyResponseV2 = new PreKeyResponse("foo", new LinkedList<PreKeyResponseItem>());
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
private final MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager); private final MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager, apnFallbackManager);
private final KeysController keysControllerV2 = mock(KeysController.class); private final KeysController keysControllerV2 = mock(KeysController.class);
@Rule @Rule
@ -112,7 +114,7 @@ public class FederatedControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(204))); assertThat("Good Response", response.getStatus(), is(equalTo(204)));
verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.Envelope.class), eq(false)); verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.Envelope.class));
} }
@Test @Test

View File

@ -18,6 +18,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -58,6 +59,7 @@ public class MessageControllerTest {
private final MessagesManager messagesManager = mock(MessagesManager.class); private final MessagesManager messagesManager = mock(MessagesManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class ); private final RateLimiters rateLimiters = mock(RateLimiters.class );
private final RateLimiter rateLimiter = mock(RateLimiter.class ); private final RateLimiter rateLimiter = mock(RateLimiter.class );
private final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
@ -67,7 +69,7 @@ public class MessageControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder()) .addProvider(new AuthValueFactoryProvider.Binder())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, .addResource(new MessageController(rateLimiters, pushSender, receiptSender, accountsManager,
messagesManager, federatedClientManager)) messagesManager, federatedClientManager, apnFallbackManager))
.build(); .build();
@ -104,7 +106,7 @@ public class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class));
} }
@Test @Test
@ -157,7 +159,7 @@ public class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class));
} }
@Test @Test

View File

@ -20,7 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService; import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Date; import java.util.Date;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -64,7 +63,7 @@ public class APNSenderTest {
.thenReturn(result); .thenReturn(result);
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -75,8 +74,8 @@ public class APNSenderTest {
verify(apnsClient, times(1)).sendNotification(notification.capture()); verify(apnsClient, times(1)).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip"); assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip");
@ -101,7 +100,7 @@ public class APNSenderTest {
.thenReturn(result); .thenReturn(result);
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", false, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -112,8 +111,8 @@ public class APNSenderTest {
verify(apnsClient, times(1)).sendNotification(notification.capture()); verify(apnsClient, times(1)).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo"); assertThat(notification.getValue().getTopic()).isEqualTo("foo");
@ -124,57 +123,57 @@ public class APNSenderTest {
verifyNoMoreInteractions(fallbackManager); verifyNoMoreInteractions(fallbackManager);
} }
// @Test @Test
// public void testUnregisteredUser() throws Exception { public void testUnregisteredUser() throws Exception {
// ApnsClient apnsClient = mock(ApnsClient.class); ApnsClient apnsClient = mock(ApnsClient.class);
//
// PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class); PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
// when(response.isAccepted()).thenReturn(false); when(response.isAccepted()).thenReturn(false);
// when(response.getRejectionReason()).thenReturn("Unregistered"); when(response.getRejectionReason()).thenReturn("Unregistered");
//
// DefaultPromise<PushNotificationResponse<SimpleApnsPushNotification>> result = new DefaultPromise<>(executor); DefaultPromise<PushNotificationResponse<SimpleApnsPushNotification>> result = new DefaultPromise<>(executor);
// result.setSuccess(response); result.setSuccess(response);
//
// when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
// .thenReturn(result); .thenReturn(result);
//
// RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
// ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
// APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
// apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
//
// when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID);
// when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(11)); when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(11));
//
// ListenableFuture<ApnResult> sendFuture = apnSender.sendMessage(message); ListenableFuture<ApnResult> sendFuture = apnSender.sendMessage(message);
// ApnResult apnResult = sendFuture.get(); ApnResult apnResult = sendFuture.get();
//
// Thread.sleep(1000); // =( Thread.sleep(1000); // =(
//
// ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
// verify(apnsClient, times(1)).sendNotification(notification.capture()); verify(apnsClient, times(1)).sendNotification(notification.capture());
//
// assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
// assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
// assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
// assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
//
// assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER);
//
// verifyNoMoreInteractions(apnsClient); verifyNoMoreInteractions(apnsClient);
// verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER));
// verify(destinationAccount, times(1)).getDevice(1); verify(destinationAccount, times(1)).getDevice(1);
// verify(destinationDevice, times(1)).getApnId(); verify(destinationDevice, times(1)).getApnId();
// verify(destinationDevice, times(1)).getPushTimestamp(); verify(destinationDevice, times(1)).getPushTimestamp();
// verify(destinationDevice, times(1)).setApnId(eq((String)null)); // verify(destinationDevice, times(1)).setApnId(eq((String)null));
// verify(destinationDevice, times(1)).setVoipApnId(eq((String)null)); // verify(destinationDevice, times(1)).setVoipApnId(eq((String)null));
// verify(destinationDevice, times(1)).setFetchesMessages(eq(false)); // verify(destinationDevice, times(1)).setFetchesMessages(eq(false));
// verify(accountsManager, times(1)).update(eq(destinationAccount)); // verify(accountsManager, times(1)).update(eq(destinationAccount));
// verify(fallbackManager, times(1)).cancel(eq(new WebsocketAddress(DESTINATION_NUMBER, 1))); verify(fallbackManager, times(1)).cancel(eq(destinationAccount), eq(destinationDevice));
//
// verifyNoMoreInteractions(accountsManager); verifyNoMoreInteractions(accountsManager);
// verifyNoMoreInteractions(fallbackManager); verifyNoMoreInteractions(fallbackManager);
// } }
// @Test // @Test
// public void testVoipUnregisteredUser() throws Exception { // public void testVoipUnregisteredUser() throws Exception {
@ -230,54 +229,54 @@ public class APNSenderTest {
// verifyNoMoreInteractions(fallbackManager); // verifyNoMoreInteractions(fallbackManager);
// } // }
// @Test @Test
// public void testRecentUnregisteredUser() throws Exception { public void testRecentUnregisteredUser() throws Exception {
// ApnsClient apnsClient = mock(ApnsClient.class); ApnsClient apnsClient = mock(ApnsClient.class);
//
// PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class); PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
// when(response.isAccepted()).thenReturn(false); when(response.isAccepted()).thenReturn(false);
// when(response.getRejectionReason()).thenReturn("Unregistered"); when(response.getRejectionReason()).thenReturn("Unregistered");
//
// DefaultPromise<PushNotificationResponse<SimpleApnsPushNotification>> result = new DefaultPromise<>(executor); DefaultPromise<PushNotificationResponse<SimpleApnsPushNotification>> result = new DefaultPromise<>(executor);
// result.setSuccess(response); result.setSuccess(response);
//
// when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
// .thenReturn(result); .thenReturn(result);
//
// RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
// ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
// APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
// apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
//
// when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID);
// when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis()); when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis());
//
// ListenableFuture<ApnResult> sendFuture = apnSender.sendMessage(message); ListenableFuture<ApnResult> sendFuture = apnSender.sendMessage(message);
// ApnResult apnResult = sendFuture.get(); ApnResult apnResult = sendFuture.get();
//
// Thread.sleep(1000); // =( Thread.sleep(1000); // =(
//
// ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
// verify(apnsClient, times(1)).sendNotification(notification.capture()); verify(apnsClient, times(1)).sendNotification(notification.capture());
//
// assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
// assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
// assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
// assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
//
// assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER);
//
// verifyNoMoreInteractions(apnsClient); verifyNoMoreInteractions(apnsClient);
// verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER));
// verify(destinationAccount, times(1)).getDevice(1); verify(destinationAccount, times(1)).getDevice(1);
// verify(destinationDevice, times(1)).getApnId(); verify(destinationDevice, times(1)).getApnId();
// verify(destinationDevice, times(1)).getPushTimestamp(); verify(destinationDevice, times(1)).getPushTimestamp();
//
// verifyNoMoreInteractions(destinationDevice); verifyNoMoreInteractions(destinationDevice);
// verifyNoMoreInteractions(destinationAccount); verifyNoMoreInteractions(destinationAccount);
// verifyNoMoreInteractions(accountsManager); verifyNoMoreInteractions(accountsManager);
// verifyNoMoreInteractions(fallbackManager); verifyNoMoreInteractions(fallbackManager);
// } }
// @Test // @Test
// public void testUnregisteredUserOldApnId() throws Exception { // public void testUnregisteredUserOldApnId() throws Exception {
@ -343,7 +342,7 @@ public class APNSenderTest {
.thenReturn(result); .thenReturn(result);
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -354,8 +353,8 @@ public class APNSenderTest {
verify(apnsClient, times(1)).sendNotification(notification.capture()); verify(apnsClient, times(1)).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.GENERIC_FAILURE); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.GENERIC_FAILURE);
@ -384,7 +383,7 @@ public class APNSenderTest {
.thenReturn(connectedResult); .thenReturn(connectedResult);
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -409,8 +408,8 @@ public class APNSenderTest {
verify(apnsClient, times(1)).getReconnectionFuture(); verify(apnsClient, times(1)).getReconnectionFuture();
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.SUCCESS); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.SUCCESS);
@ -434,7 +433,7 @@ public class APNSenderTest {
.thenReturn(result); .thenReturn(result);
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 3); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 3);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager); apnSender.setApnFallbackManager(fallbackManager);
@ -451,8 +450,8 @@ public class APNSenderTest {
verify(apnsClient, times(4)).sendNotification(notification.capture()); verify(apnsClient, times(4)).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION));
assertThat(notification.getValue().getPayload()).isEqualTo("message"); assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
verifyNoMoreInteractions(apnsClient); verifyNoMoreInteractions(apnsClient);

View File

@ -1,78 +0,0 @@
package org.whispersystems.textsecuregcm.tests.push;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.List;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
public class ApnFallbackManagerTest {
@Test
public void testFullFallback() throws Exception {
APNSender apnSender = mock(APNSender.class );
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222223", 1L);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111);
ApnFallbackTask task = new ApnFallbackTask("foo", "voipfoo", message, 500, 0);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(apnSender, pubSubManager);
apnFallbackManager.start();
apnFallbackManager.schedule(address, task);
Util.sleep(1100);
ArgumentCaptor<ApnMessage> captor = ArgumentCaptor.forClass(ApnMessage.class);
verify(apnSender, times(2)).sendMessage(captor.capture());
verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager));
List<ApnMessage> arguments = captor.getAllValues();
assertEquals(arguments.get(0).getMessage(), message.getMessage());
assertEquals(arguments.get(0).getApnId(), task.getVoipApnId());
// assertEquals(arguments.get(0).getExpirationTime(), Integer.MAX_VALUE * 1000L);
assertEquals(arguments.get(1).getMessage(), message.getMessage());
assertEquals(arguments.get(1).getApnId(), task.getApnId());
assertEquals(arguments.get(1).getExpirationTime(), Integer.MAX_VALUE * 1000L);
}
@Test
public void testNoFallback() throws Exception {
APNSender pushServiceClient = mock(APNSender.class);
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222222", 1);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 5555);
ApnFallbackTask task = new ApnFallbackTask ("foo", "voipfoo", message, 500, 0);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
apnFallbackManager.start();
apnFallbackManager.schedule(address, task);
apnFallbackManager.onDispatchMessage(info.serialize(),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.CONNECTED)
.build().toByteArray());
verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager));
Util.sleep(1100);
verifyNoMoreInteractions(pushServiceClient);
}
}

View File

@ -1,92 +0,0 @@
package org.whispersystems.textsecuregcm.tests.push;
import org.junit.Test;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTaskQueue;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ApnFallbackTaskQueueTest {
@Test
public void testBlocking() {
final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
final WebsocketAddress address = mock(WebsocketAddress.class);
final ApnFallbackTask task = mock(ApnFallbackTask.class );
when(task.getExecutionTime()).thenReturn(System.currentTimeMillis() - 1000);
new Thread() {
@Override
public void run() {
Util.sleep(500);
taskQueue.put(address, task);
}
}.start();
Map.Entry<WebsocketAddress, ApnFallbackTask> result = taskQueue.get();
assertEquals(result.getKey(), address);
assertEquals(result.getValue(), task);
}
@Test
public void testElapsedTime() {
final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
final WebsocketAddress address = mock(WebsocketAddress.class);
final ApnFallbackTask task = mock(ApnFallbackTask.class );
long currentTime = System.currentTimeMillis();
when(task.getExecutionTime()).thenReturn(currentTime + 1000);
taskQueue.put(address, task);
Map.Entry<WebsocketAddress, ApnFallbackTask> result = taskQueue.get();
assertTrue(System.currentTimeMillis() >= currentTime + 1000);
assertEquals(result.getKey(), address);
assertEquals(result.getValue(), task);
}
@Test
public void testCanceled() {
final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
final WebsocketAddress addressOne = mock(WebsocketAddress.class);
final ApnFallbackTask taskOne = mock(ApnFallbackTask.class );
final WebsocketAddress addressTwo = mock(WebsocketAddress.class);
final ApnFallbackTask taskTwo = mock(ApnFallbackTask.class );
long currentTime = System.currentTimeMillis();
when(taskOne.getExecutionTime()).thenReturn(currentTime + 1000);
when(taskTwo.getExecutionTime()).thenReturn(currentTime + 2000);
taskQueue.put(addressOne, taskOne);
taskQueue.put(addressTwo, taskTwo);
new Thread() {
@Override
public void run() {
Util.sleep(300);
taskQueue.remove(addressOne);
}
}.start();
Map.Entry<WebsocketAddress, ApnFallbackTask> result = taskQueue.get();
assertTrue(System.currentTimeMillis() >= currentTime + 2000);
assertEquals(result.getKey(), addressTwo);
assertEquals(result.getValue(), taskTwo);
}
}

View File

@ -11,6 +11,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender; import org.whispersystems.textsecuregcm.push.WebsocketSender;
@ -58,12 +59,13 @@ public class WebSocketConnectionTest {
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
private static final PushSender pushSender = mock(PushSender.class); private static final PushSender pushSender = mock(PushSender.class);
private static final ReceiptSender receiptSender = mock(ReceiptSender.class); private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
@Test @Test
public void testCredentials() throws Exception { public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class); MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, storedMessages, pubSubManager); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@ -250,7 +252,7 @@ public class WebSocketConnectionTest {
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.<String>absent())); verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.<String>absent()));
verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class)); verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class));
verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(true)); verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device));
connection.onDispatchUnsubscribed(websocketAddress.serialize()); connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString()); verify(client).close(anyInt(), anyString());