diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index 01a0e4ee1..94c5be7be 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -87,11 +87,11 @@ public class PushSender implements Managed { } } - public void sendQueuedNotification(Account account, Device device, int messageQueueDepth) + public void sendQueuedNotification(Account account, Device device, int messageQueueDepth, boolean fallback) throws NotPushRegisteredException, TransientPushFailureException { if (device.getGcmId() != null) sendGcmNotification(account, device); - else if (device.getApnId() != null) sendApnNotification(account, device, messageQueueDepth); + else if (device.getApnId() != null) sendApnNotification(account, device, messageQueueDepth, fallback); else if (!device.getFetchesMessages()) throw new NotPushRegisteredException("No notification possible!"); } @@ -129,11 +129,12 @@ public class PushSender implements Managed { DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN); if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) { - sendApnNotification(account, device, deliveryStatus.getMessageQueueDepth()); + boolean fallback = !outgoingMessage.getSource().equals(account.getNumber()); + sendApnNotification(account, device, deliveryStatus.getMessageQueueDepth(), fallback); } } - private void sendApnNotification(Account account, Device device, int messageQueueDepth) { + private void sendApnNotification(Account account, Device device, int messageQueueDepth, boolean fallback) { ApnMessage apnMessage; if (!Util.isEmpty(device.getVoipApnId())) { @@ -141,8 +142,10 @@ public class PushSender implements Managed { String.format(APN_PAYLOAD, messageQueueDepth), true, System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(30)); - apnFallbackManager.schedule(new WebsocketAddress(account.getNumber(), device.getId()), - new ApnFallbackTask(device.getApnId(), apnMessage)); + if (fallback) { + apnFallbackManager.schedule(new WebsocketAddress(account.getNumber(), device.getId()), + new ApnFallbackTask(device.getApnId(), apnMessage)); + } } else { apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(), String.format(APN_PAYLOAD, messageQueueDepth), diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index c1ac498dd..5b6ec41cb 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -140,10 +140,11 @@ public class WebSocketConnection implements DispatchChannel { } private void requeueMessage(Envelope message) { - int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message); + int queueDepth = pushSender.getWebSocketSender().queueMessage(account, device, message); + boolean fallback = !message.getSource().equals(account.getNumber()); try { - pushSender.sendQueuedNotification(account, device, queueDepth); + pushSender.sendQueuedNotification(account, device, queueDepth, fallback); } catch (NotPushRegisteredException | TransientPushFailureException e) { logger.warn("requeueMessage", e); } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 910b06372..4d83b0e56 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -250,7 +250,7 @@ public class WebSocketConnectionTest { verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.absent())); verify(websocketSender, times(1)).queueMessage(eq(account), eq(device), any(Envelope.class)); - verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(10)); + verify(pushSender, times(1)).sendQueuedNotification(eq(account), eq(device), eq(10), eq(true)); connection.onDispatchUnsubscribed(websocketAddress.serialize()); verify(client).close(anyInt(), anyString());