diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 16f3e184e..a1c5434d8 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -55,6 +55,7 @@ import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge; import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisHealthCheck; import org.whispersystems.textsecuregcm.providers.TimeProvider; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.FeedbackHandler; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushServiceClient; @@ -170,15 +171,17 @@ public class WhisperServerService extends Application nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration()); SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational()); UrlSigner urlSigner = new UrlSigner(config.getS3Configuration()); - PushSender pushSender = new PushSender(pushServiceClient, websocketSender); + PushSender pushSender = new PushSender(apnFallbackManager, pushServiceClient, websocketSender); ReceiptSender receiptSender = new ReceiptSender(accountsManager, pushSender, federatedClientManager); FeedbackHandler feedbackHandler = new FeedbackHandler(pushServiceClient, accountsManager); Optional authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey(); + environment.lifecycle().manage(apnFallbackManager); environment.lifecycle().manage(pubSubManager); environment.lifecycle().manage(feedbackHandler); @@ -207,7 +210,7 @@ public class WhisperServerService extends Application taskEntry = taskQueue.get(); + ApnFallbackTask task = taskEntry.getValue(); + int retryCount = task.getRetryCount(); + + if (retryCount == 0) { + pushServiceClient.send(task.getMessage()); + schedule(taskEntry.getKey(), new ApnFallbackTask(task.getApnId(), task.getMessage(), + retryCount + 1, task.getDelay())); + } else if (retryCount == 1) { + pushServiceClient.send(new ApnMessage(task.getMessage(), task.getApnId(), false)); + } + } catch (Throwable e) { + logger.warn("ApnFallbackThread", e); + } + } + } + + public static class ApnFallbackTask { + + private final long delay; + private final long executionTime; + private final String apnId; + private final ApnMessage message; + private final int retryCount; + + public ApnFallbackTask(String apnId, ApnMessage message, int retryCount) { + this(apnId, message, retryCount, TimeUnit.SECONDS.toMillis(15)); + } + + @VisibleForTesting + public ApnFallbackTask(String apnId, ApnMessage message, int retryCount, long delay) { + this.executionTime = System.currentTimeMillis() + delay; + this.delay = delay; + this.apnId = apnId; + this.message = message; + this.retryCount = retryCount; + } + + public String getApnId() { + return apnId; + } + + public ApnMessage getMessage() { + return message; + } + + public int getRetryCount() { + return retryCount; + } + + public long getExecutionTime() { + return executionTime; + } + + public long getDelay() { + return delay; + } + } + + @VisibleForTesting + public static class ApnFallbackTaskQueue { + + private final LinkedHashMap tasks = new LinkedHashMap<>(); + + public Entry get() { + while (true) { + long timeDelta; + + synchronized (tasks) { + while (tasks.isEmpty()) Util.wait(tasks); + + Iterator> iterator = tasks.entrySet().iterator(); + Entry nextTask = iterator.next(); + + timeDelta = nextTask.getValue().getExecutionTime() - System.currentTimeMillis(); + + if (timeDelta <= 0) { + iterator.remove(); + return nextTask; + } + } + + Util.sleep(timeDelta); + } + } + + public void put(WebsocketAddress address, ApnFallbackTask task) { + synchronized (tasks) { + tasks.put(address, task); + tasks.notifyAll(); + } + } + + public void remove(WebsocketAddress address) { + synchronized (tasks) { + tasks.remove(address); + } + } + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index feda80cd1..1e97e15de 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -22,11 +22,16 @@ import org.whispersystems.textsecuregcm.entities.ApnMessage; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.GcmMessage; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; +import java.util.LinkedHashMap; + +import io.dropwizard.lifecycle.Managed; import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; public class PushSender { @@ -35,12 +40,14 @@ public class PushSender { private static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"badge\":%d,\"alert\":{\"loc-key\":\"APN_Message\"}}}"; - private final PushServiceClient pushServiceClient; - private final WebsocketSender webSocketSender; + private final ApnFallbackManager apnFallbackManager; + private final PushServiceClient pushServiceClient; + private final WebsocketSender webSocketSender; - public PushSender(PushServiceClient pushServiceClient, WebsocketSender websocketSender) { - this.pushServiceClient = pushServiceClient; - this.webSocketSender = websocketSender; + public PushSender(ApnFallbackManager apnFallbackManager, PushServiceClient pushServiceClient, WebsocketSender websocketSender) { + this.apnFallbackManager = apnFallbackManager; + this.pushServiceClient = pushServiceClient; + this.webSocketSender = websocketSender; } public void sendMessage(Account account, Device device, OutgoingMessageSignal message) @@ -106,6 +113,12 @@ public class PushSender { ApnMessage apnMessage = new ApnMessage(apnId, account.getNumber(), (int)device.getId(), String.format(APN_PAYLOAD, deliveryStatus.getMessageQueueDepth()), isVoip); + + if (isVoip) { + apnFallbackManager.schedule(new WebsocketAddress(account.getNumber(), device.getId()), + new ApnFallbackTask(device.getApnId(), apnMessage, 0)); + } + pushServiceClient.send(apnMessage); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java index 722062ec7..1fe77389b 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java +++ b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java @@ -115,12 +115,20 @@ public class Util { return parts; } - public static void sleep(int i) { + public static void sleep(long i) { try { Thread.sleep(i); } catch (InterruptedException ie) {} } + public static void wait(Object object) { + try { + object.wait(); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } + public static long todayInMillis() { return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis())); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index d3c5d8211..e30f45f4e 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -5,6 +5,7 @@ import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; @@ -25,22 +26,23 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration")); - - private final AccountsManager accountsManager; - private final PushSender pushSender; - private final ReceiptSender receiptSender; - private final MessagesManager messagesManager; - private final PubSubManager pubSubManager; + private final ApnFallbackManager apnFallbackManager; + private final AccountsManager accountsManager; + private final PushSender pushSender; + private final ReceiptSender receiptSender; + private final MessagesManager messagesManager; + private final PubSubManager pubSubManager; public AuthenticatedConnectListener(AccountsManager accountsManager, PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, - PubSubManager pubSubManager) + PubSubManager pubSubManager, ApnFallbackManager apnFallbackManager) { - this.accountsManager = accountsManager; - this.pushSender = pushSender; - this.receiptSender = receiptSender; - this.messagesManager = messagesManager; - this.pubSubManager = pubSubManager; + this.accountsManager = accountsManager; + this.pushSender = pushSender; + this.receiptSender = receiptSender; + this.messagesManager = messagesManager; + this.pubSubManager = pubSubManager; + this.apnFallbackManager = apnFallbackManager; } @Override @@ -53,6 +55,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { messagesManager, account, device, context.getClient()); + apnFallbackManager.cancel(address); updateLastSeen(account, device); pubSubManager.subscribe(address, connection); diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java new file mode 100644 index 000000000..86f765c3a --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java @@ -0,0 +1,83 @@ +package org.whispersystems.textsecuregcm.tests.push; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.entities.ApnMessage; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; +import org.whispersystems.textsecuregcm.push.PushServiceClient; +import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +public class ApnFallbackManagerTest { + + @Test + public void testFullFallback() throws Exception { + PushServiceClient pushServiceClient = mock(PushServiceClient.class); + WebsocketAddress address = mock(WebsocketAddress.class ); + ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true); + ApnFallbackTask task = new ApnFallbackTask("foo", message, 0, 500); + + ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient); + apnFallbackManager.start(); + + apnFallbackManager.schedule(address, task); + + Util.sleep(1100); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ApnMessage.class); + verify(pushServiceClient, times(2)).send(captor.capture()); + + List messages = captor.getAllValues(); + assertEquals(messages.get(0), message); + assertEquals(messages.get(1).getApnId(), task.getApnId()); + assertFalse(messages.get(1).isVoip()); + } + + @Test + public void testPartialFallback() throws Exception { + PushServiceClient pushServiceClient = mock(PushServiceClient.class); + WebsocketAddress address = mock(WebsocketAddress.class ); + ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true); + ApnFallbackTask task = new ApnFallbackTask ("foo", message, 0, 500); + + ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient); + apnFallbackManager.start(); + + apnFallbackManager.schedule(address, task); + + Util.sleep(600); + + apnFallbackManager.cancel(address); + + Util.sleep(600); + + verify(pushServiceClient, times(1)).send(eq(message)); + } + + @Test + public void testNoFallback() throws Exception { + PushServiceClient pushServiceClient = mock(PushServiceClient.class); + WebsocketAddress address = mock(WebsocketAddress.class ); + ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true); + ApnFallbackTask task = new ApnFallbackTask ("foo", message, 0, 500); + + ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient); + apnFallbackManager.start(); + + apnFallbackManager.schedule(address, task); + apnFallbackManager.cancel(address); + + Util.sleep(1100); + + verifyNoMoreInteractions(pushServiceClient); + } + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java new file mode 100644 index 000000000..a2617ec93 --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackTaskQueueTest.java @@ -0,0 +1,93 @@ +package org.whispersystems.textsecuregcm.tests.push; + + +import org.junit.Test; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTaskQueue; +import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ApnFallbackTaskQueueTest { + + @Test + public void testBlocking() { + final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); + + final WebsocketAddress address = mock(WebsocketAddress.class); + final ApnFallbackTask task = mock(ApnFallbackTask.class ); + + when(task.getExecutionTime()).thenReturn(System.currentTimeMillis() - 1000); + + new Thread() { + @Override + public void run() { + Util.sleep(500); + taskQueue.put(address, task); + } + }.start(); + + Map.Entry result = taskQueue.get(); + + assertEquals(result.getKey(), address); + assertEquals(result.getValue(), task); + } + + @Test + public void testElapsedTime() { + final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); + final WebsocketAddress address = mock(WebsocketAddress.class); + final ApnFallbackTask task = mock(ApnFallbackTask.class ); + + long currentTime = System.currentTimeMillis(); + + when(task.getExecutionTime()).thenReturn(currentTime + 1000); + + taskQueue.put(address, task); + Map.Entry result = taskQueue.get(); + + assertTrue(System.currentTimeMillis() >= currentTime + 1000); + assertEquals(result.getKey(), address); + assertEquals(result.getValue(), task); + } + + @Test + public void testCanceled() { + final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); + final WebsocketAddress addressOne = mock(WebsocketAddress.class); + final ApnFallbackTask taskOne = mock(ApnFallbackTask.class ); + final WebsocketAddress addressTwo = mock(WebsocketAddress.class); + final ApnFallbackTask taskTwo = mock(ApnFallbackTask.class ); + + long currentTime = System.currentTimeMillis(); + + when(taskOne.getExecutionTime()).thenReturn(currentTime + 1000); + when(taskTwo.getExecutionTime()).thenReturn(currentTime + 2000); + + taskQueue.put(addressOne, taskOne); + taskQueue.put(addressTwo, taskTwo); + + new Thread() { + @Override + public void run() { + Util.sleep(300); + taskQueue.remove(addressOne); + } + }.start(); + + Map.Entry result = taskQueue.get(); + + assertTrue(System.currentTimeMillis() >= currentTime + 2000); + assertEquals(result.getKey(), addressTwo); + assertEquals(result.getValue(), taskTwo); + } + + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 105aceb16..ff4bb1903 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -10,6 +10,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; @@ -60,12 +61,13 @@ public class WebSocketConnectionTest { private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); private static final PushSender pushSender = mock(PushSender.class); private static final ReceiptSender receiptSender = mock(ReceiptSender.class); + private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); @Test public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, storedMessages, pubSubManager); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, storedMessages, pubSubManager, apnFallbackManager); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))