diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d91141162..ca61a71ce 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -193,6 +193,7 @@ public class WhisperServerService extends Application messageBody = getMessageBody(incomingMessage); @@ -218,9 +218,6 @@ public class MessageController { } catch (NotPushRegisteredException e) { if (destinationDevice.isMaster()) throw new NoSuchUserException(e); else logger.debug("Not registered", e); - } catch (TransientPushFailureException e) { - if (destinationDevice.isMaster()) throw new IOException(e); - else logger.debug("Transient failure", e); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index 69ec7ef41..6582e5aea 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -16,6 +16,9 @@ */ package org.whispersystems.textsecuregcm.push; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.ApnMessage; @@ -24,36 +27,56 @@ import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor; +import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.concurrent.TimeUnit; +import static com.codahale.metrics.MetricRegistry.name; +import io.dropwizard.lifecycle.Managed; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -public class PushSender { +public class PushSender implements Managed { private final Logger logger = LoggerFactory.getLogger(PushSender.class); private static final String APN_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"badge\":%d,\"alert\":{\"loc-key\":\"APN_Message\"}}}"; - private final ApnFallbackManager apnFallbackManager; - private final PushServiceClient pushServiceClient; - private final WebsocketSender webSocketSender; + private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private final Histogram queueDepthGauge = metricRegistry.histogram(name(getClass(), "queue_depth")); + + private final ApnFallbackManager apnFallbackManager; + private final PushServiceClient pushServiceClient; + private final WebsocketSender webSocketSender; + private final BlockingThreadPoolExecutor executor; public PushSender(ApnFallbackManager apnFallbackManager, PushServiceClient pushServiceClient, WebsocketSender websocketSender) { this.apnFallbackManager = apnFallbackManager; this.pushServiceClient = pushServiceClient; this.webSocketSender = websocketSender; + this.executor = new BlockingThreadPoolExecutor(50, 200); } - public void sendMessage(Account account, Device device, Envelope message) - throws NotPushRegisteredException, TransientPushFailureException + public void sendMessage(final Account account, final Device device, final Envelope message) + throws NotPushRegisteredException { - if (device.getGcmId() != null) sendGcmMessage(account, device, message); - else if (device.getApnId() != null) sendApnMessage(account, device, message); - else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); - else throw new NotPushRegisteredException("No delivery possible!"); + if (device.getGcmId() == null && device.getApnId() == null && !device.getFetchesMessages()) { + throw new NotPushRegisteredException("No delivery possible!"); + } + + executor.execute(new Runnable() { + @Override + public void run() { + if (device.getGcmId() != null) sendGcmMessage(account, device, message); + else if (device.getApnId() != null) sendApnMessage(account, device, message); + else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); + else throw new AssertionError(); + } + }); + + queueDepthGauge.update(executor.getSize()); } public void sendQueuedNotification(Account account, Device device, int messageQueueDepth) @@ -68,9 +91,7 @@ public class PushSender { return webSocketSender; } - private void sendGcmMessage(Account account, Device device, Envelope message) - throws TransientPushFailureException - { + private void sendGcmMessage(Account account, Device device, Envelope message) { DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, message, WebsocketSender.Type.GCM); if (!deliveryStatus.isDelivered()) { @@ -78,18 +99,18 @@ public class PushSender { } } - private void sendGcmNotification(Account account, Device device) - throws TransientPushFailureException - { - GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(), - (int)device.getId(), "", false, true); + private void sendGcmNotification(Account account, Device device) { + try { + GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(), + (int)device.getId(), "", false, true); - pushServiceClient.send(gcmMessage); + pushServiceClient.send(gcmMessage); + } catch (TransientPushFailureException e) { + logger.warn("SILENT PUSH LOSS", e); + } } - private void sendApnMessage(Account account, Device device, Envelope outgoingMessage) - throws TransientPushFailureException - { + private void sendApnMessage(Account account, Device device, Envelope outgoingMessage) { DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN); if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != Envelope.Type.RECEIPT) { @@ -97,9 +118,7 @@ public class PushSender { } } - private void sendApnNotification(Account account, Device device, int messageQueueDepth) - throws TransientPushFailureException - { + private void sendApnNotification(Account account, Device device, int messageQueueDepth) { ApnMessage apnMessage; if (!Util.isEmpty(device.getVoipApnId())) { @@ -115,11 +134,26 @@ public class PushSender { false, ApnMessage.MAX_EXPIRATION); } - pushServiceClient.send(apnMessage); + try { + pushServiceClient.send(apnMessage); + } catch (TransientPushFailureException e) { + logger.warn("SILENT PUSH LOSS", e); + } } private void sendWebSocketMessage(Account account, Device device, Envelope outgoingMessage) { webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.WEB); } + + @Override + public void start() throws Exception { + + } + + @Override + public void stop() throws Exception { + executor.shutdown(); + executor.awaitTermination(5, TimeUnit.MINUTES); + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/BlockingThreadPoolExecutor.java b/src/main/java/org/whispersystems/textsecuregcm/util/BlockingThreadPoolExecutor.java new file mode 100644 index 000000000..a0eadd1d8 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/util/BlockingThreadPoolExecutor.java @@ -0,0 +1,37 @@ +package org.whispersystems.textsecuregcm.util; + +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +public class BlockingThreadPoolExecutor extends ThreadPoolExecutor { + + private final Semaphore semaphore; + + public BlockingThreadPoolExecutor(int threads, int bound) { + super(threads, threads, 1, TimeUnit.SECONDS, new LinkedBlockingQueue()); + this.semaphore = new Semaphore(bound); + } + + @Override + public void execute(Runnable task) { + semaphore.acquireUninterruptibly(); + + try { + super.execute(task); + } catch (Throwable t) { + semaphore.release(); + throw new RuntimeException(t); + } + } + + @Override + protected void afterExecute(Runnable r, Throwable t) { + semaphore.release(); + } + + public int getSize() { + return ((LinkedBlockingQueue)getQueue()).size(); + } +} \ No newline at end of file diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/util/BlockingThreadPoolExecutorTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/util/BlockingThreadPoolExecutorTest.java new file mode 100644 index 000000000..cc9b326f1 --- /dev/null +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/util/BlockingThreadPoolExecutorTest.java @@ -0,0 +1,58 @@ +package org.whispersystems.textsecuregcm.tests.util; + +import org.junit.Test; +import org.whispersystems.textsecuregcm.util.BlockingThreadPoolExecutor; +import org.whispersystems.textsecuregcm.util.Util; + +import static org.junit.Assert.assertTrue; + +public class BlockingThreadPoolExecutorTest { + + @Test + public void testBlocking() { + BlockingThreadPoolExecutor executor = new BlockingThreadPoolExecutor(1, 3); + long start = System.currentTimeMillis(); + + executor.execute(new Runnable() { + @Override + public void run() { + Util.sleep(1000); + } + }); + + assertTrue(System.currentTimeMillis() - start < 500); + start = System.currentTimeMillis(); + + executor.execute(new Runnable() { + @Override + public void run() { + Util.sleep(1000); + } + }); + + assertTrue(System.currentTimeMillis() - start < 500); + + start = System.currentTimeMillis(); + + executor.execute(new Runnable() { + @Override + public void run() { + Util.sleep(1000); + } + }); + + assertTrue(System.currentTimeMillis() - start < 500); + + start = System.currentTimeMillis(); + + executor.execute(new Runnable() { + @Override + public void run() { + Util.sleep(1000); + } + }); + + assertTrue(System.currentTimeMillis() - start > 500); + } + +}