diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 033232713..e213f1f4b 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -76,6 +76,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private RedisConfiguration directory; + @NotNull + @Valid + @JsonProperty + private RedisConfiguration pushScheduler; + @NotNull @Valid @JsonProperty @@ -170,6 +175,10 @@ public class WhisperServerConfiguration extends Configuration { return messageCache; } + public RedisConfiguration getPushScheduler() { + return pushScheduler; + } + public DataSourceFactory getMessageStoreConfiguration() { return messageStore; } diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 639f42964..57a592036 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -158,12 +158,15 @@ public class WhisperServerService extends Application() @@ -227,7 +230,7 @@ public class WhisperServerService extends Application apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get())); + } + return messagesManager.getMessagesForDevice(account.getNumber(), account.getAuthenticatedDevice().get().getId()); } @@ -219,7 +230,7 @@ public class MessageController { messageBuilder.setRelay(source.getRelay().get()); } - pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), incomingMessage.isSilent()); + pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build()); } catch (NotPushRegisteredException e) { if (destinationDevice.isMaster()) throw new NoSuchUserException(e); else logger.debug("Not registered", e); diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 1eb617d54..44d5eb313 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -17,7 +17,6 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; -import org.hibernate.validator.constraints.NotEmpty; public class IncomingMessage { @@ -45,10 +44,6 @@ public class IncomingMessage { @JsonProperty private long timestamp; // deprecated - @JsonProperty - private boolean silent = false; - - public String getDestination() { return destination; } @@ -77,7 +72,4 @@ public class IncomingMessage { return content; } - public boolean isSilent() { - return silent; - } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java index a1fb9a555..d7d2775f6 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/APNSender.java @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2013 Open WhisperSystems * * This program is free software: you can redistribute it and/or modify @@ -16,6 +16,9 @@ */ package org.whispersystems.textsecuregcm.push; +import com.codahale.metrics.Meter; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.util.concurrent.FutureCallback; @@ -25,10 +28,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult; +import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; +import org.whispersystems.textsecuregcm.util.Constants; import javax.annotation.Nullable; import java.io.IOException; @@ -37,12 +41,17 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import static com.codahale.metrics.MetricRegistry.name; import io.dropwizard.lifecycle.Managed; public class APNSender implements Managed { private final Logger logger = LoggerFactory.getLogger(APNSender.class); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Meter unregisteredEventStale = metricRegistry.meter(name(APNSender.class, "unregistered_event_stale")); + private static final Meter unregisteredEventFresh = metricRegistry.meter(name(APNSender.class, "unregistered_event_fresh")); + private ExecutorService executor; private ApnFallbackManager fallbackManager; @@ -71,9 +80,7 @@ public class APNSender implements Managed { this.bundleId = bundleId; } - public ListenableFuture sendMessage(final ApnMessage message) - throws TransientPushFailureException - { + public ListenableFuture sendMessage(final ApnMessage message) { String topic = bundleId; if (message.isVoip()) { @@ -106,13 +113,13 @@ public class APNSender implements Managed { } @Override - public void start() throws Exception { + public void start() { this.executor = Executors.newSingleThreadExecutor(); this.apnsClient.connect(sandbox); } @Override - public void stop() throws Exception { + public void stop() { this.executor.shutdown(); this.apnsClient.disconnect(); } @@ -121,13 +128,14 @@ public class APNSender implements Managed { this.fallbackManager = fallbackManager; } - private void handleUnregisteredUser(String registrationId, String number, int deviceId) { - logger.info("Got APN Unregistered: " + number + "," + deviceId); + private void handleUnregisteredUser(String registrationId, String number, long deviceId) { +// logger.info("Got APN Unregistered: " + number + "," + deviceId); Optional account = accountsManager.get(number); if (!account.isPresent()) { logger.info("No account found: " + number); + unregisteredEventStale.mark(); return; } @@ -135,6 +143,7 @@ public class APNSender implements Managed { if (!device.isPresent()) { logger.info("No device found: " + number); + unregisteredEventStale.mark(); return; } @@ -142,24 +151,26 @@ public class APNSender implements Managed { !registrationId.equals(device.get().getVoipApnId())) { logger.info("Registration ID does not match: " + registrationId + ", " + device.get().getApnId() + ", " + device.get().getVoipApnId()); + unregisteredEventStale.mark(); return; } - if (registrationId.equals(device.get().getApnId())) { - logger.info("APN Unregister APN ID matches! " + number + ", " + deviceId); - } else if (registrationId.equals(device.get().getVoipApnId())) { - logger.info("APN Unregister VoIP ID matches! " + number + ", " + deviceId); - } +// if (registrationId.equals(device.get().getApnId())) { +// logger.info("APN Unregister APN ID matches! " + number + ", " + deviceId); +// } else if (registrationId.equals(device.get().getVoipApnId())) { +// logger.info("APN Unregister VoIP ID matches! " + number + ", " + deviceId); +// } long tokenTimestamp = device.get().getPushTimestamp(); if (tokenTimestamp != 0 && System.currentTimeMillis() < tokenTimestamp + TimeUnit.SECONDS.toMillis(10)) { logger.info("APN Unregister push timestamp is more recent: " + tokenTimestamp + ", " + number); + unregisteredEventStale.mark(); return; } - logger.info("APN Unregister timestamp matches: " + device.get().getApnId() + ", " + device.get().getVoipApnId()); +// logger.info("APN Unregister timestamp matches: " + device.get().getApnId() + ", " + device.get().getVoipApnId()); // device.get().setApnId(null); // device.get().setVoipApnId(null); // device.get().setFetchesMessages(false); @@ -168,5 +179,10 @@ public class APNSender implements Managed { // if (fallbackManager != null) { // fallbackManager.cancel(new WebsocketAddress(number, deviceId)); // } + + if (fallbackManager != null) { + RedisOperation.unchecked(() -> fallbackManager.cancel(account.get(), device.get())); + unregisteredEventFresh.mark(); + } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java b/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java index 3fde9ba30..929f06c75 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java @@ -1,237 +1,247 @@ package org.whispersystems.textsecuregcm.push; -import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.RatioGauge; import com.codahale.metrics.SharedMetricRegistries; -import com.google.common.annotations.VisibleForTesting; -import com.google.protobuf.InvalidProtocolBufferException; +import com.google.common.base.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.dispatch.DispatchChannel; -import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; +import org.whispersystems.textsecuregcm.redis.LuaScript; +import org.whispersystems.textsecuregcm.redis.RedisException; +import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.Map.Entry; -import java.util.concurrent.TimeUnit; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import static com.codahale.metrics.MetricRegistry.name; import io.dropwizard.lifecycle.Managed; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.exceptions.JedisException; -public class ApnFallbackManager implements Managed, Runnable, DispatchChannel { +@SuppressWarnings("Guava") +public class ApnFallbackManager implements Managed, Runnable { private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class); - public static final int FALLBACK_DURATION = 15; + private static final String PENDING_NOTIFICATIONS_KEY = "PENDING_APN"; - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private static final Meter voipOneSuccess = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_success")); - private static final Meter voipOneDelivery = metricRegistry.meter(name(ApnFallbackManager.class, "voip_one_failure")); - private static final Histogram voipOneSuccessHistogram = metricRegistry.histogram(name(ApnFallbackManager.class, "voip_one_success_histogram")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Meter delivered = metricRegistry.meter(name(ApnFallbackManager.class, "voip_delivered")); + private static final Meter sent = metricRegistry.meter(name(ApnFallbackManager.class, "voip_sent" )); + private static final Meter retry = metricRegistry.meter(name(ApnFallbackManager.class, "voip_retry")); static { - metricRegistry.register(name(ApnFallbackManager.class, "voip_one_success_ratio"), new VoipRatioGauge(voipOneSuccess, voipOneDelivery)); + metricRegistry.register(name(ApnFallbackManager.class, "voip_ratio"), new VoipRatioGauge(delivered, sent)); } - private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); + private final APNSender apnSender; + private final AccountsManager accountsManager; - private final APNSender apnSender; - private final PubSubManager pubSubManager; + private final ReplicatedJedisPool jedisPool; + private final InsertOperation insertOperation; + private final GetOperation getOperation; + private final RemoveOperation removeOperation; - public ApnFallbackManager(APNSender apnSender, PubSubManager pubSubManager) { - this.apnSender = apnSender; - this.pubSubManager = pubSubManager; + + private AtomicBoolean running = new AtomicBoolean(false); + private boolean finished; + + public ApnFallbackManager(ReplicatedJedisPool jedisPool, + APNSender apnSender, + AccountsManager accountsManager) + throws IOException + { + this.apnSender = apnSender; + this.accountsManager = accountsManager; + this.jedisPool = jedisPool; + this.insertOperation = new InsertOperation(jedisPool); + this.getOperation = new GetOperation(jedisPool); + this.removeOperation = new RemoveOperation(jedisPool); } - public void schedule(final WebsocketAddress address, ApnFallbackTask task) { - voipOneDelivery.mark(); - - if (taskQueue.put(address, task)) { - pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); + public void schedule(Account account, Device device) throws RedisException { + try { + sent.mark(); + insertOperation.insert(account, device, System.currentTimeMillis() + (15 * 1000), (15 * 1000)); + } catch (JedisException e) { + throw new RedisException(e); } } - private void scheduleRetry(final WebsocketAddress address, ApnFallbackTask task) { - if (taskQueue.putIfMissing(address, task)) { - pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); + public boolean isScheduled(Account account, Device device) throws RedisException { + try { + String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId(); + + try (Jedis jedis = jedisPool.getReadResource()) { + return jedis.zscore(PENDING_NOTIFICATIONS_KEY, endpoint) != null; + } + } catch (JedisException e) { + throw new RedisException(e); } } - public void cancel(WebsocketAddress address) { - ApnFallbackTask task = taskQueue.remove(address); - - if (task != null) { - pubSubManager.unsubscribe(new WebSocketConnectionInfo(address), this); - voipOneSuccess.mark(); - voipOneSuccessHistogram.update(System.currentTimeMillis() - task.getScheduledTime()); + public void cancel(Account account, Device device) throws RedisException { + try { + if (removeOperation.remove(account, device)) { + delivered.mark(); + } + } catch (JedisException e) { + throw new RedisException(e); } } @Override - public void start() throws Exception { + public synchronized void start() { + running.set(true); new Thread(this).start(); } @Override - public void stop() throws Exception { - + public synchronized void stop() { + running.set(false); + while (!finished) Util.wait(this); } @Override public void run() { - while (true) { + while (running.get()) { try { - Entry taskEntry = taskQueue.get(); - ApnFallbackTask task = taskEntry.getValue(); + List pendingNotifications = getOperation.getPending(100); - ApnMessage message; + for (byte[] pendingNotification : pendingNotifications) { + String numberAndDevice = new String(pendingNotification); + Optional> separated = getSeparated(numberAndDevice); - if (task.getAttempt() == 0) { - message = new ApnMessage(task.getMessage(), task.getVoipApnId(), true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(FALLBACK_DURATION)); - scheduleRetry(taskEntry.getKey(), new ApnFallbackTask(task.getApnId(), task.getVoipApnId(), task.getMessage(), task.getDelay(),1)); - } else { - message = new ApnMessage(task.getMessage(), task.getApnId(), false, ApnMessage.MAX_EXPIRATION); - pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this); - } - - apnSender.sendMessage(message); - } catch (Throwable e) { - logger.warn("ApnFallbackThread", e); - } - } - } - - @Override - public void onDispatchMessage(String channel, byte[] message) { - try { - PubSubMessage notification = PubSubMessage.parseFrom(message); - - if (notification.getType().getNumber() == PubSubMessage.Type.CONNECTED_VALUE) { - WebSocketConnectionInfo address = new WebSocketConnectionInfo(channel); - cancel(address.getWebsocketAddress()); - } else { - logger.warn("Got strange pubsub type: " + notification.getType().getNumber()); - } - - } catch (WebSocketConnectionInfo.FormattingException e) { - logger.warn("Bad formatting?", e); - } catch (InvalidProtocolBufferException e) { - logger.warn("Bad protobuf", e); - } - } - - @Override - public void onDispatchSubscribed(String channel) {} - - @Override - public void onDispatchUnsubscribed(String channel) {} - - public static class ApnFallbackTask { - - private final long delay; - private final long scheduledTime; - private final String apnId; - private final String voipApnId; - private final ApnMessage message; - private final int attempt; - - public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message) { - this(apnId, voipApnId, message, TimeUnit.SECONDS.toMillis(FALLBACK_DURATION), 0); - } - - @VisibleForTesting - public ApnFallbackTask(String apnId, String voipApnId, ApnMessage message, long delay, int attempt) { - this.scheduledTime = System.currentTimeMillis(); - this.delay = delay; - this.apnId = apnId; - this.voipApnId = voipApnId; - this.message = message; - this.attempt = attempt; - } - - public String getApnId() { - return apnId; - } - - public String getVoipApnId() { - return voipApnId; - } - - public ApnMessage getMessage() { - return message; - } - - public long getScheduledTime() { - return scheduledTime; - } - - public long getExecutionTime() { - return scheduledTime + delay; - } - - public long getDelay() { - return delay; - } - - public int getAttempt() { - return attempt; - } - } - - @VisibleForTesting - public static class ApnFallbackTaskQueue { - - private final LinkedHashMap tasks = new LinkedHashMap<>(); - - public Entry get() { - while (true) { - long timeDelta; - - synchronized (tasks) { - while (tasks.isEmpty()) Util.wait(tasks); - - Iterator> iterator = tasks.entrySet().iterator(); - Entry nextTask = iterator.next(); - - timeDelta = nextTask.getValue().getExecutionTime() - System.currentTimeMillis(); - - if (timeDelta <= 0) { - iterator.remove(); - return nextTask; + if (!separated.isPresent()) { + removeOperation.remove(numberAndDevice); + continue; } + + Optional account = accountsManager.get(separated.get().first()); + + if (!account.isPresent()) { + removeOperation.remove(numberAndDevice); + continue; + } + + Optional device = account.get().getDevice(separated.get().second()); + + if (!device.isPresent()) { + removeOperation.remove(numberAndDevice); + continue; + } + + String apnId = device.get().getVoipApnId(); + + if (apnId == null) { + removeOperation.remove(account.get(), device.get()); + continue; + } + + apnSender.sendMessage(new ApnMessage(apnId, separated.get().first(), separated.get().second(), true)); + retry.mark(); } - Util.sleep(timeDelta); + } catch (Exception e) { + logger.warn("Exception while operating", e); } + + Util.sleep(1000); } - public boolean put(WebsocketAddress address, ApnFallbackTask task) { - synchronized (tasks) { - ApnFallbackTask previous = tasks.put(address, task); - tasks.notifyAll(); + synchronized (ApnFallbackManager.this) { + finished = true; + notifyAll(); + } + } - return previous == null; + private Optional> getSeparated(String encoded) { + try { + if (encoded == null) return Optional.absent(); + + String[] parts = encoded.split(":"); + + if (parts.length != 2) { + logger.warn("Got strange encoded number: " + encoded); + return Optional.absent(); } + + return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1]))); + } catch (NumberFormatException e) { + logger.warn("Badly formatted: " + encoded, e); + return Optional.absent(); + } + } + + private static class RemoveOperation { + + private final LuaScript luaScript; + + RemoveOperation(ReplicatedJedisPool jedisPool) throws IOException { + this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/remove.lua"); } - public boolean putIfMissing(WebsocketAddress address, ApnFallbackTask task) { - synchronized (tasks) { - if (tasks.containsKey(address)) return false; - return put(address, task); - } + boolean remove(Account account, Device device) { + String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId(); + return remove(endpoint); } - public ApnFallbackTask remove(WebsocketAddress address) { - synchronized (tasks) { - return tasks.remove(address); + boolean remove(String endpoint) { + if (!PENDING_NOTIFICATIONS_KEY.equals(endpoint)) { + List keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes()); + List args = Collections.emptyList(); + + return ((long)luaScript.execute(keys, args)) > 0; } + + return false; + } + + } + + private static class GetOperation { + + private final LuaScript luaScript; + + GetOperation(ReplicatedJedisPool jedisPool) throws IOException { + this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/get.lua"); + } + + @SuppressWarnings("SameParameterValue") + List getPending(int limit) { + List keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes()); + List args = Arrays.asList(String.valueOf(System.currentTimeMillis()).getBytes(), String.valueOf(limit).getBytes()); + + return (List) luaScript.execute(keys, args); + } + } + + private static class InsertOperation { + + private final LuaScript luaScript; + + InsertOperation(ReplicatedJedisPool jedisPool) throws IOException { + this.luaScript = LuaScript.fromResource(jedisPool, "lua/apn/insert.lua"); + } + + public void insert(Account account, Device device, long timestamp, long interval) { + String endpoint = "apn_device::" + account.getNumber() + "::" + device.getId(); + + List keys = Arrays.asList(PENDING_NOTIFICATIONS_KEY.getBytes(), endpoint.getBytes()); + List args = Arrays.asList(String.valueOf(timestamp).getBytes(), String.valueOf(interval).getBytes(), + account.getNumber().getBytes(), String.valueOf(device.getId()).getBytes()); + + luaScript.execute(keys, args); } } @@ -247,7 +257,7 @@ public class ApnFallbackManager implements Managed, Runnable, DispatchChannel { @Override protected Ratio getRatio() { - return Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate()); + return RatioGauge.Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/ApnMessage.java b/src/main/java/org/whispersystems/textsecuregcm/push/ApnMessage.java index 972e5ff62..c87b372e4 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/ApnMessage.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/ApnMessage.java @@ -2,33 +2,21 @@ package org.whispersystems.textsecuregcm.push; public class ApnMessage { - public static long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L; + public static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}"; + public static final long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L; private final String apnId; private final String number; - private final int deviceId; - private final String message; + private final long deviceId; private final boolean isVoip; - private final long expirationTime; - public ApnMessage(String apnId, String number, int deviceId, String message, boolean isVoip, long expirationTime) { + public ApnMessage(String apnId, String number, long deviceId, boolean isVoip) { this.apnId = apnId; this.number = number; this.deviceId = deviceId; - this.message = message; this.isVoip = isVoip; - this.expirationTime = expirationTime; } - - public ApnMessage(ApnMessage copy, String apnId, boolean isVoip, long expirationTime) { - this.apnId = apnId; - this.number = copy.number; - this.deviceId = copy.deviceId; - this.message = copy.message; - this.isVoip = isVoip; - this.expirationTime = expirationTime; - } - + public boolean isVoip() { return isVoip; } @@ -38,18 +26,18 @@ public class ApnMessage { } public String getMessage() { - return message; + return APN_PAYLOAD; } public long getExpirationTime() { - return expirationTime; + return MAX_EXPIRATION; } public String getNumber() { return number; } - public int getDeviceId() { + public long getDeviceId() { return deviceId; } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index 3d22227c3..aa2e1c5be 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -20,14 +20,13 @@ import com.codahale.metrics.Gauge; import com.codahale.metrics.SharedMetricRegistries; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus; +import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.concurrent.TimeUnit; @@ -37,10 +36,9 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; public class PushSender implements Managed { + @SuppressWarnings("unused") private final Logger logger = LoggerFactory.getLogger(PushSender.class); - private static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}"; - private final ApnFallbackManager apnFallbackManager; private final GCMSender gcmSender; private final APNSender apnSender; @@ -64,7 +62,7 @@ public class PushSender implements Managed { (Gauge) executor::getSize); } - public void sendMessage(final Account account, final Device device, final Envelope message, final boolean silent) + public void sendMessage(final Account account, final Device device, final Envelope message) throws NotPushRegisteredException { if (device.getGcmId() == null && device.getApnId() == null && !device.getFetchesMessages()) { @@ -72,17 +70,17 @@ public class PushSender implements Managed { } if (queueSize > 0) { - executor.execute(() -> sendSynchronousMessage(account, device, message, silent)); + executor.execute(() -> sendSynchronousMessage(account, device, message)); } else { - sendSynchronousMessage(account, device, message, silent); + sendSynchronousMessage(account, device, message); } } - public void sendQueuedNotification(Account account, Device device, boolean fallback) - throws NotPushRegisteredException, TransientPushFailureException + public void sendQueuedNotification(Account account, Device device) + throws NotPushRegisteredException { if (device.getGcmId() != null) sendGcmNotification(account, device); - else if (device.getApnId() != null) sendApnNotification(account, device, fallback); + else if (device.getApnId() != null) sendApnNotification(account, device, true); else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!"); } @@ -90,9 +88,9 @@ public class PushSender implements Managed { return webSocketSender; } - private void sendSynchronousMessage(Account account, Device device, Envelope message, boolean silent) { + private void sendSynchronousMessage(Account account, Device device, Envelope message) { if (device.getGcmId() != null) sendGcmMessage(account, device, message); - else if (device.getApnId() != null) sendApnMessage(account, device, message, silent); + else if (device.getApnId() != null) sendApnMessage(account, device, message); else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); else throw new AssertionError(); } @@ -112,36 +110,29 @@ public class PushSender implements Managed { gcmSender.sendMessage(gcmMessage); } - private void sendApnMessage(Account account, Device device, Envelope outgoingMessage, boolean silent) { + private void sendApnMessage(Account account, Device device, Envelope outgoingMessage) { DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN); if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) { - boolean fallback = !silent && !outgoingMessage.getSource().equals(account.getNumber()); - sendApnNotification(account, device, fallback); + sendApnNotification(account, device, false); } } - private void sendApnNotification(Account account, Device device, boolean fallback) { + private void sendApnNotification(Account account, Device device, boolean newOnly) { ApnMessage apnMessage; + if (newOnly && RedisOperation.unchecked(() -> apnFallbackManager.isScheduled(account, device))) { + return; + } + if (!Util.isEmpty(device.getVoipApnId())) { - apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, true, - System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(ApnFallbackManager.FALLBACK_DURATION)); - - if (fallback) { - apnFallbackManager.schedule(new WebsocketAddress(account.getNumber(), device.getId()), - new ApnFallbackTask(device.getApnId(), device.getVoipApnId(), apnMessage)); - } + apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true); + RedisOperation.unchecked(() -> apnFallbackManager.schedule(account, device)); } else { - apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), APN_PAYLOAD, - false, ApnMessage.MAX_EXPIRATION); + apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false); } - try { - apnSender.sendMessage(apnMessage); - } catch (TransientPushFailureException e) { - logger.warn("SILENT PUSH LOSS", e); - } + apnSender.sendMessage(apnMessage); } private void sendWebSocketMessage(Account account, Device device, Envelope outgoingMessage) @@ -163,4 +154,5 @@ public class PushSender implements Managed { apnSender.stop(); gcmSender.stop(); } + } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index 84cb576a2..f87ee86e5 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -72,7 +72,7 @@ public class ReceiptSender { } for (Device destinationDevice : destinationDevices) { - pushSender.sendMessage(destinationAccount, destinationDevice, message.build(), true); + pushSender.sendMessage(destinationAccount, destinationDevice, message.build()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/redis/RedisException.java b/src/main/java/org/whispersystems/textsecuregcm/redis/RedisException.java new file mode 100644 index 000000000..c45cefbed --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/redis/RedisException.java @@ -0,0 +1,8 @@ +package org.whispersystems.textsecuregcm.redis; + +public class RedisException extends Exception { + + public RedisException(Exception e) { + super(e); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/redis/RedisOperation.java b/src/main/java/org/whispersystems/textsecuregcm/redis/RedisOperation.java new file mode 100644 index 000000000..47f6255a8 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/redis/RedisOperation.java @@ -0,0 +1,37 @@ +package org.whispersystems.textsecuregcm.redis; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.push.PushSender; + +public class RedisOperation { + + private static final Logger logger = LoggerFactory.getLogger(RedisOperation.class); + + public static void unchecked(Operation operation) { + try { + operation.run(); + } catch (RedisException e) { + logger.warn("Jedis failure", e); + } + } + + public static boolean unchecked(BooleanOperation operation) { + try { + return operation.run(); + } catch (RedisException e) { + logger.warn("Jedis failure", e); + } + + return false; + } + + @FunctionalInterface + public interface Operation { + public void run() throws RedisException; + } + + public interface BooleanOperation { + public boolean run() throws RedisException; + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 459ce4ab8..64e15dc14 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -12,7 +12,6 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushSender; -import org.whispersystems.textsecuregcm.push.TransientPushFailureException; import org.whispersystems.textsecuregcm.redis.LuaScript; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.util.Constants; @@ -483,11 +482,9 @@ public class MessagesCache implements Managed { if (device.isPresent()) { try { - pushSender.sendQueuedNotification(account.get(), device.get(), false); + pushSender.sendQueuedNotification(account.get(), device.get()); } catch (NotPushRegisteredException e) { logger.warn("After message persistence, no longer push registered!"); - } catch (TransientPushFailureException e) { - logger.warn("Transient push failure!", e); } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index ceca0bdb3..b2864f696 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -6,16 +6,16 @@ import com.codahale.metrics.Timer; import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; +import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.util.Constants; -import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; @@ -29,21 +29,23 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final Timer durationTimer = metricRegistry.timer(name(WebSocketConnection.class, "connected_duration")); - private final AccountsManager accountsManager; private final PushSender pushSender; private final ReceiptSender receiptSender; private final MessagesManager messagesManager; private final PubSubManager pubSubManager; + private final ApnFallbackManager apnFallbackManager; - public AuthenticatedConnectListener(AccountsManager accountsManager, PushSender pushSender, - ReceiptSender receiptSender, MessagesManager messagesManager, - PubSubManager pubSubManager) + public AuthenticatedConnectListener(PushSender pushSender, + ReceiptSender receiptSender, + MessagesManager messagesManager, + PubSubManager pubSubManager, + ApnFallbackManager apnFallbackManager) { - this.accountsManager = accountsManager; - this.pushSender = pushSender; - this.receiptSender = receiptSender; - this.messagesManager = messagesManager; - this.pubSubManager = pubSubManager; + this.pushSender = pushSender; + this.receiptSender = receiptSender; + this.messagesManager = messagesManager; + this.pubSubManager = pubSubManager; + this.apnFallbackManager = apnFallbackManager; } @Override @@ -53,15 +55,14 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final String connectionId = String.valueOf(new SecureRandom().nextLong()); final Timer.Context timer = durationTimer.time(); final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); - final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, - messagesManager, account, device, - context.getClient(), connectionId); + messagesManager, account, device, + context.getClient(), connectionId); final PubSubMessage connectMessage = PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED) .setContent(ByteString.copyFrom(connectionId.getBytes())) .build(); - pubSubManager.publish(info, connectMessage); + RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); pubSubManager.publish(address, connectMessage); pubSubManager.subscribe(address, connection); diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java index a22152356..a82f0e583 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java @@ -20,24 +20,22 @@ public class DeadLetterHandler implements DispatchChannel { @Override public void onDispatchMessage(String channel, byte[] data) { - if (!WebSocketConnectionInfo.isType(channel)) { - try { - logger.info("Handling dead letter to: " + channel); + try { + logger.info("Handling dead letter to: " + channel); - WebsocketAddress address = new WebsocketAddress(channel); - PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data); + WebsocketAddress address = new WebsocketAddress(channel); + PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data); - switch (pubSubMessage.getType().getNumber()) { - case PubSubMessage.Type.DELIVER_VALUE: - Envelope message = Envelope.parseFrom(pubSubMessage.getContent()); - messagesManager.insert(address.getNumber(), address.getDeviceId(), message); - break; - } - } catch (InvalidProtocolBufferException e) { - logger.warn("Bad pubsub message", e); - } catch (InvalidWebsocketAddressException e) { - logger.warn("Invalid websocket address", e); + switch (pubSubMessage.getType().getNumber()) { + case PubSubMessage.Type.DELIVER_VALUE: + Envelope message = Envelope.parseFrom(pubSubMessage.getContent()); + messagesManager.insert(address.getNumber(), address.getDeviceId(), message); + break; } + } catch (InvalidProtocolBufferException e) { + logger.warn("Bad pubsub message", e); + } catch (InvalidWebsocketAddressException e) { + logger.warn("Invalid websocket address", e); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 723fd11b3..11313393e 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -150,11 +150,9 @@ public class WebSocketConnection implements DispatchChannel { private void requeueMessage(Envelope message) { pushSender.getWebSocketSender().queueMessage(account, device, message); - boolean fallback = !message.getSource().equals(account.getNumber()) && message.getType() != Envelope.Type.RECEIPT; - try { - pushSender.sendQueuedNotification(account, device, fallback); - } catch (NotPushRegisteredException | TransientPushFailureException e) { + pushSender.sendQueuedNotification(account, device); + } catch (NotPushRegisteredException e) { logger.warn("requeueMessage", e); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java deleted file mode 100644 index 0683672ee..000000000 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java +++ /dev/null @@ -1,62 +0,0 @@ -package org.whispersystems.textsecuregcm.websocket; - -import org.whispersystems.textsecuregcm.storage.PubSubAddress; -import org.whispersystems.textsecuregcm.util.Util; - -public class WebSocketConnectionInfo implements PubSubAddress { - - private final WebsocketAddress address; - - public WebSocketConnectionInfo(WebsocketAddress address) { - this.address = address; - } - - public WebSocketConnectionInfo(String serialized) throws FormattingException { - String[] parts = serialized.split("[:]", 3); - - if (parts.length != 3 || !"c".equals(parts[2])) { - throw new FormattingException("Bad address: " + serialized); - } - - try { - this.address = new WebsocketAddress(parts[0], Long.parseLong(parts[1])); - } catch (NumberFormatException e) { - throw new FormattingException(e); - } - } - - public String serialize() { - return address.serialize() + ":c"; - } - - public WebsocketAddress getWebsocketAddress() { - return address; - } - - public static boolean isType(String address) { - return address.endsWith(":c"); - } - - @Override - public boolean equals(Object other) { - return - other != null && - other instanceof WebSocketConnectionInfo - && ((WebSocketConnectionInfo)other).address.equals(address); - } - - @Override - public int hashCode() { - return Util.hashCode(address, "c"); - } - - public static class FormattingException extends Exception { - public FormattingException(String message) { - super(message); - } - - public FormattingException(Exception e) { - super(e); - } - } -} diff --git a/src/main/resources/lua/apn/get.lua b/src/main/resources/lua/apn/get.lua new file mode 100644 index 000000000..7c79e97f5 --- /dev/null +++ b/src/main/resources/lua/apn/get.lua @@ -0,0 +1,66 @@ +-- keys: pending (KEYS[1]) +-- argv: max_time (ARGV[1]), limit (ARGV[2]) + +local hgetall = function (key) + local bulk = redis.call('HGETALL', key) + local result = {} + local nextkey + for i, v in ipairs(bulk) do + if i % 2 == 1 then + nextkey = v + else + result[nextkey] = v + end + end + return result +end + +local getNextInterval = function(interval) + if interval < 20000 then + return 20000 + end + + if interval < 40000 then + return 40000 + end + + if interval < 80000 then + return 80000 + end + + if interval < 160000 then + return 160000 + end + + if interval < 600000 then + return 600000 + end + + return 1800000 +end + + +local results = redis.call("ZRANGEBYSCORE", KEYS[1], 0, ARGV[1], "LIMIT", 0, ARGV[2]) +local collated = {} + +if results and next(results) then + for i, name in ipairs(results) do + local pending = hgetall(name) + local lastInterval = pending["interval"] + + if lastInterval == nil then + lastInterval = 0 + end + + local nextInterval = getNextInterval(tonumber(lastInterval)) + + redis.call("HSET", name, "interval", nextInterval) + redis.call("ZADD", KEYS[1], tonumber(ARGV[1]) + nextInterval, name) + + collated[i] = pending["account"] .. ":" .. pending["device"] + end +end + +return collated + + diff --git a/src/main/resources/lua/apn/insert.lua b/src/main/resources/lua/apn/insert.lua new file mode 100644 index 000000000..e9ac5971f --- /dev/null +++ b/src/main/resources/lua/apn/insert.lua @@ -0,0 +1,8 @@ +-- keys: pending (KEYS[1]), user (KEYS[2]) +-- args: timestamp (ARGV[1]), interval (ARGV[2]), account (ARGV[3]), device (ARGV[4]) + +redis.call("HSET", KEYS[2], "interval", ARGV[2]) +redis.call("HSET", KEYS[2], "account", ARGV[3]) +redis.call("HSET", KEYS[2], "device", ARGV[4]) + +redis.call("ZADD", KEYS[1], ARGV[1], KEYS[2]) diff --git a/src/main/resources/lua/apn/remove.lua b/src/main/resources/lua/apn/remove.lua new file mode 100644 index 000000000..2fac20596 --- /dev/null +++ b/src/main/resources/lua/apn/remove.lua @@ -0,0 +1,4 @@ +-- keys: queue KEYS[1], endpoint (KEYS[2]) + +redis.call("DEL", KEYS[2]) +return redis.call("ZREM", KEYS[1], KEYS[2]) diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java index 6f9bdca64..fe2516d12 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java @@ -20,6 +20,7 @@ import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; @@ -56,13 +57,14 @@ public class FederatedControllerTest { private MessagesManager messagesManager = mock(MessagesManager.class); private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiter rateLimiter = mock(RateLimiter.class ); + private ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); private final SignedPreKey signedPreKey = new SignedPreKey(3333, "foo", "baar"); private final PreKeyResponse preKeyResponseV2 = new PreKeyResponse("foo", new LinkedList()); private final ObjectMapper mapper = new ObjectMapper(); - private final MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager); + private final MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager, apnFallbackManager); private final KeysController keysControllerV2 = mock(KeysController.class); @Rule @@ -112,7 +114,7 @@ public class FederatedControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(204))); - verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.Envelope.class), eq(false)); + verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.Envelope.class)); } @Test diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index efa222821..8261e2928 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -18,6 +18,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; @@ -58,6 +59,7 @@ public class MessageControllerTest { private final MessagesManager messagesManager = mock(MessagesManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class ); private final RateLimiter rateLimiter = mock(RateLimiter.class ); + private final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); private final ObjectMapper mapper = new ObjectMapper(); @@ -67,7 +69,7 @@ public class MessageControllerTest { .addProvider(new AuthValueFactoryProvider.Binder()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, - messagesManager, federatedClientManager)) + messagesManager, federatedClientManager, apnFallbackManager)) .build(); @@ -104,7 +106,7 @@ public class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); - verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); + verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class)); } @Test @@ -157,7 +159,7 @@ public class MessageControllerTest { assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); - verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class), eq(false)); + verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(Envelope.class)); } @Test diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/APNSenderTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/APNSenderTest.java index ed5b9bcab..050854448 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/push/APNSenderTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/push/APNSenderTest.java @@ -20,7 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.Date; import java.util.concurrent.TimeUnit; @@ -64,7 +63,7 @@ public class APNSenderTest { .thenReturn(result); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); - ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); apnSender.setApnFallbackManager(fallbackManager); @@ -75,8 +74,8 @@ public class APNSenderTest { verify(apnsClient, times(1)).sendNotification(notification.capture()); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); - assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); - assertThat(notification.getValue().getPayload()).isEqualTo("message"); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip"); @@ -101,7 +100,7 @@ public class APNSenderTest { .thenReturn(result); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); - ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", false, 30); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); apnSender.setApnFallbackManager(fallbackManager); @@ -112,8 +111,8 @@ public class APNSenderTest { verify(apnsClient, times(1)).sendNotification(notification.capture()); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); - assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); - assertThat(notification.getValue().getPayload()).isEqualTo("message"); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(notification.getValue().getTopic()).isEqualTo("foo"); @@ -124,57 +123,57 @@ public class APNSenderTest { verifyNoMoreInteractions(fallbackManager); } -// @Test -// public void testUnregisteredUser() throws Exception { -// ApnsClient apnsClient = mock(ApnsClient.class); -// -// PushNotificationResponse response = mock(PushNotificationResponse.class); -// when(response.isAccepted()).thenReturn(false); -// when(response.getRejectionReason()).thenReturn("Unregistered"); -// -// DefaultPromise> result = new DefaultPromise<>(executor); -// result.setSuccess(response); -// -// when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) -// .thenReturn(result); -// -// RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); -// ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); -// APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); -// apnSender.setApnFallbackManager(fallbackManager); -// -// when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); -// when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(11)); -// -// ListenableFuture sendFuture = apnSender.sendMessage(message); -// ApnResult apnResult = sendFuture.get(); -// -// Thread.sleep(1000); // =( -// -// ArgumentCaptor notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); -// verify(apnsClient, times(1)).sendNotification(notification.capture()); -// -// assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); -// assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); -// assertThat(notification.getValue().getPayload()).isEqualTo("message"); -// assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); -// -// assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); -// -// verifyNoMoreInteractions(apnsClient); -// verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); -// verify(destinationAccount, times(1)).getDevice(1); -// verify(destinationDevice, times(1)).getApnId(); -// verify(destinationDevice, times(1)).getPushTimestamp(); + @Test + public void testUnregisteredUser() throws Exception { + ApnsClient apnsClient = mock(ApnsClient.class); + + PushNotificationResponse response = mock(PushNotificationResponse.class); + when(response.isAccepted()).thenReturn(false); + when(response.getRejectionReason()).thenReturn("Unregistered"); + + DefaultPromise> result = new DefaultPromise<>(executor); + result.setSuccess(response); + + when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) + .thenReturn(result); + + RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); + APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); + apnSender.setApnFallbackManager(fallbackManager); + + when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); + when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(11)); + + ListenableFuture sendFuture = apnSender.sendMessage(message); + ApnResult apnResult = sendFuture.get(); + + Thread.sleep(1000); // =( + + ArgumentCaptor notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); + verify(apnsClient, times(1)).sendNotification(notification.capture()); + + assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); + assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); + + assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); + + verifyNoMoreInteractions(apnsClient); + verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); + verify(destinationAccount, times(1)).getDevice(1); + verify(destinationDevice, times(1)).getApnId(); + verify(destinationDevice, times(1)).getPushTimestamp(); // verify(destinationDevice, times(1)).setApnId(eq((String)null)); // verify(destinationDevice, times(1)).setVoipApnId(eq((String)null)); // verify(destinationDevice, times(1)).setFetchesMessages(eq(false)); // verify(accountsManager, times(1)).update(eq(destinationAccount)); -// verify(fallbackManager, times(1)).cancel(eq(new WebsocketAddress(DESTINATION_NUMBER, 1))); -// -// verifyNoMoreInteractions(accountsManager); -// verifyNoMoreInteractions(fallbackManager); -// } + verify(fallbackManager, times(1)).cancel(eq(destinationAccount), eq(destinationDevice)); + + verifyNoMoreInteractions(accountsManager); + verifyNoMoreInteractions(fallbackManager); + } // @Test // public void testVoipUnregisteredUser() throws Exception { @@ -230,54 +229,54 @@ public class APNSenderTest { // verifyNoMoreInteractions(fallbackManager); // } -// @Test -// public void testRecentUnregisteredUser() throws Exception { -// ApnsClient apnsClient = mock(ApnsClient.class); -// -// PushNotificationResponse response = mock(PushNotificationResponse.class); -// when(response.isAccepted()).thenReturn(false); -// when(response.getRejectionReason()).thenReturn("Unregistered"); -// -// DefaultPromise> result = new DefaultPromise<>(executor); -// result.setSuccess(response); -// -// when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) -// .thenReturn(result); -// -// RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); -// ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); -// APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); -// apnSender.setApnFallbackManager(fallbackManager); -// -// when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); -// when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis()); -// -// ListenableFuture sendFuture = apnSender.sendMessage(message); -// ApnResult apnResult = sendFuture.get(); -// -// Thread.sleep(1000); // =( -// -// ArgumentCaptor notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); -// verify(apnsClient, times(1)).sendNotification(notification.capture()); -// -// assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); -// assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); -// assertThat(notification.getValue().getPayload()).isEqualTo("message"); -// assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); -// -// assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); -// -// verifyNoMoreInteractions(apnsClient); -// verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); -// verify(destinationAccount, times(1)).getDevice(1); -// verify(destinationDevice, times(1)).getApnId(); -// verify(destinationDevice, times(1)).getPushTimestamp(); -// -// verifyNoMoreInteractions(destinationDevice); -// verifyNoMoreInteractions(destinationAccount); -// verifyNoMoreInteractions(accountsManager); -// verifyNoMoreInteractions(fallbackManager); -// } + @Test + public void testRecentUnregisteredUser() throws Exception { + ApnsClient apnsClient = mock(ApnsClient.class); + + PushNotificationResponse response = mock(PushNotificationResponse.class); + when(response.isAccepted()).thenReturn(false); + when(response.getRejectionReason()).thenReturn("Unregistered"); + + DefaultPromise> result = new DefaultPromise<>(executor); + result.setSuccess(response); + + when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class))) + .thenReturn(result); + + RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); + APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); + apnSender.setApnFallbackManager(fallbackManager); + + when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID); + when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis()); + + ListenableFuture sendFuture = apnSender.sendMessage(message); + ApnResult apnResult = sendFuture.get(); + + Thread.sleep(1000); // =( + + ArgumentCaptor notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class); + verify(apnsClient, times(1)).sendNotification(notification.capture()); + + assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); + assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); + + assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.NO_SUCH_USER); + + verifyNoMoreInteractions(apnsClient); + verify(accountsManager, times(1)).get(eq(DESTINATION_NUMBER)); + verify(destinationAccount, times(1)).getDevice(1); + verify(destinationDevice, times(1)).getApnId(); + verify(destinationDevice, times(1)).getPushTimestamp(); + + verifyNoMoreInteractions(destinationDevice); + verifyNoMoreInteractions(destinationAccount); + verifyNoMoreInteractions(accountsManager); + verifyNoMoreInteractions(fallbackManager); + } // @Test // public void testUnregisteredUserOldApnId() throws Exception { @@ -343,7 +342,7 @@ public class APNSenderTest { .thenReturn(result); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); - ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); apnSender.setApnFallbackManager(fallbackManager); @@ -354,8 +353,8 @@ public class APNSenderTest { verify(apnsClient, times(1)).sendNotification(notification.capture()); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); - assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); - assertThat(notification.getValue().getPayload()).isEqualTo("message"); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.GENERIC_FAILURE); @@ -384,7 +383,7 @@ public class APNSenderTest { .thenReturn(connectedResult); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 10); - ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); apnSender.setApnFallbackManager(fallbackManager); @@ -409,8 +408,8 @@ public class APNSenderTest { verify(apnsClient, times(1)).getReconnectionFuture(); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); - assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); - assertThat(notification.getValue().getPayload()).isEqualTo("message"); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); assertThat(apnResult.getStatus()).isEqualTo(ApnResult.Status.SUCCESS); @@ -434,7 +433,7 @@ public class APNSenderTest { .thenReturn(result); RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient, 3); - ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, "message", true, 30); + ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true); APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false); apnSender.setApnFallbackManager(fallbackManager); @@ -451,8 +450,8 @@ public class APNSenderTest { verify(apnsClient, times(4)).sendNotification(notification.capture()); assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID); - assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(30)); - assertThat(notification.getValue().getPayload()).isEqualTo("message"); + assertThat(notification.getValue().getExpiration()).isEqualTo(new Date(ApnMessage.MAX_EXPIRATION)); + assertThat(notification.getValue().getPayload()).isEqualTo(ApnMessage.APN_PAYLOAD); assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE); verifyNoMoreInteractions(apnsClient); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java deleted file mode 100644 index ad91daf2a..000000000 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java +++ /dev/null @@ -1,78 +0,0 @@ -package org.whispersystems.textsecuregcm.tests.push; - -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.push.APNSender; -import org.whispersystems.textsecuregcm.push.ApnFallbackManager; -import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; -import org.whispersystems.textsecuregcm.push.ApnMessage; -import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.PubSubProtos; -import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; - -import java.util.List; - -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -public class ApnFallbackManagerTest { - - @Test - public void testFullFallback() throws Exception { - APNSender apnSender = mock(APNSender.class ); - PubSubManager pubSubManager = mock(PubSubManager.class); - WebsocketAddress address = new WebsocketAddress("+14152222223", 1L); - WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); - ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111); - ApnFallbackTask task = new ApnFallbackTask("foo", "voipfoo", message, 500, 0); - - ApnFallbackManager apnFallbackManager = new ApnFallbackManager(apnSender, pubSubManager); - apnFallbackManager.start(); - - apnFallbackManager.schedule(address, task); - - Util.sleep(1100); - - ArgumentCaptor captor = ArgumentCaptor.forClass(ApnMessage.class); - verify(apnSender, times(2)).sendMessage(captor.capture()); - verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager)); - - List arguments = captor.getAllValues(); - - assertEquals(arguments.get(0).getMessage(), message.getMessage()); - assertEquals(arguments.get(0).getApnId(), task.getVoipApnId()); -// assertEquals(arguments.get(0).getExpirationTime(), Integer.MAX_VALUE * 1000L); - - assertEquals(arguments.get(1).getMessage(), message.getMessage()); - assertEquals(arguments.get(1).getApnId(), task.getApnId()); - assertEquals(arguments.get(1).getExpirationTime(), Integer.MAX_VALUE * 1000L); - } - - @Test - public void testNoFallback() throws Exception { - APNSender pushServiceClient = mock(APNSender.class); - PubSubManager pubSubManager = mock(PubSubManager.class); - WebsocketAddress address = new WebsocketAddress("+14152222222", 1); - WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); - ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 5555); - ApnFallbackTask task = new ApnFallbackTask ("foo", "voipfoo", message, 500, 0); - - ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager); - apnFallbackManager.start(); - - apnFallbackManager.schedule(address, task); - apnFallbackManager.onDispatchMessage(info.serialize(), - PubSubProtos.PubSubMessage.newBuilder() - .setType(PubSubProtos.PubSubMessage.Type.CONNECTED) - .build().toByteArray()); - - verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager)); - - Util.sleep(1100); - - verifyNoMoreInteractions(pushServiceClient); - } - -} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java deleted file mode 100644 index 6bb16ae3f..000000000 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java +++ /dev/null @@ -1,92 +0,0 @@ -package org.whispersystems.textsecuregcm.tests.push; - - -import org.junit.Test; -import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; -import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTaskQueue; -import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; - -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ApnFallbackTaskQueueTest { - - @Test - public void testBlocking() { - final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); - - final WebsocketAddress address = mock(WebsocketAddress.class); - final ApnFallbackTask task = mock(ApnFallbackTask.class ); - - when(task.getExecutionTime()).thenReturn(System.currentTimeMillis() - 1000); - - new Thread() { - @Override - public void run() { - Util.sleep(500); - taskQueue.put(address, task); - } - }.start(); - - Map.Entry result = taskQueue.get(); - - assertEquals(result.getKey(), address); - assertEquals(result.getValue(), task); - } - - @Test - public void testElapsedTime() { - final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); - final WebsocketAddress address = mock(WebsocketAddress.class); - final ApnFallbackTask task = mock(ApnFallbackTask.class ); - - long currentTime = System.currentTimeMillis(); - - when(task.getExecutionTime()).thenReturn(currentTime + 1000); - - taskQueue.put(address, task); - Map.Entry result = taskQueue.get(); - - assertTrue(System.currentTimeMillis() >= currentTime + 1000); - assertEquals(result.getKey(), address); - assertEquals(result.getValue(), task); - } - - @Test - public void testCanceled() { - final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); - final WebsocketAddress addressOne = mock(WebsocketAddress.class); - final ApnFallbackTask taskOne = mock(ApnFallbackTask.class ); - final WebsocketAddress addressTwo = mock(WebsocketAddress.class); - final ApnFallbackTask taskTwo = mock(ApnFallbackTask.class ); - - long currentTime = System.currentTimeMillis(); - - when(taskOne.getExecutionTime()).thenReturn(currentTime + 1000); - when(taskTwo.getExecutionTime()).thenReturn(currentTime + 2000); - - taskQueue.put(addressOne, taskOne); - taskQueue.put(addressTwo, taskTwo); - - new Thread() { - @Override - public void run() { - Util.sleep(300); - taskQueue.remove(addressOne); - } - }.start(); - - Map.Entry result = taskQueue.get(); - - assertTrue(System.currentTimeMillis() >= currentTime + 2000); - assertEquals(result.getKey(), addressTwo); - assertEquals(result.getValue(), taskTwo); - } - - -} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 2e10578c3..fb7eb8c34 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -11,6 +11,7 @@ import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.WebsocketSender; @@ -58,12 +59,13 @@ public class WebSocketConnectionTest { private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); private static final PushSender pushSender = mock(PushSender.class); private static final ReceiptSender receiptSender = mock(ReceiptSender.class); + private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); @Test public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, storedMessages, pubSubManager); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) @@ -250,7 +252,7 @@ public class WebSocketConnectionTest { verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.absent())); verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class)); - verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(true)); + verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device)); connection.onDispatchUnsubscribed(websocketAddress.serialize()); verify(client).close(anyInt(), anyString());