diff --git a/protobuf/PubSubMessage.proto b/protobuf/PubSubMessage.proto index 980ab4840..cbf5c3888 100644 --- a/protobuf/PubSubMessage.proto +++ b/protobuf/PubSubMessage.proto @@ -26,6 +26,7 @@ message PubSubMessage { DELIVER = 2; KEEPALIVE = 3; CLOSE = 4; + CONNECTED = 5; } optional Type type = 1; diff --git a/src/main/java/org/whispersystems/dispatch/DispatchManager.java b/src/main/java/org/whispersystems/dispatch/DispatchManager.java index e52f104da..6fa5f292f 100644 --- a/src/main/java/org/whispersystems/dispatch/DispatchManager.java +++ b/src/main/java/org/whispersystems/dispatch/DispatchManager.java @@ -78,7 +78,7 @@ public class DispatchManager extends Thread { public boolean hasSubscription(String name) { return subscriptions.containsKey(name); } - + @Override public void run() { while (running) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 60931fb83..76c0a4b49 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -176,7 +176,7 @@ public class WhisperServerService extends Application nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration()); SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational()); diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageProtos.java b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageProtos.java index f2b7060de..799beaaa5 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageProtos.java +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageProtos.java @@ -76,7 +76,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-     * Contains an encrypted DataMessage
+     * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
      * 
*/ boolean hasLegacyMessage(); @@ -84,7 +84,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-     * Contains an encrypted DataMessage
+     * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
      * 
*/ com.google.protobuf.ByteString getLegacyMessage(); @@ -489,7 +489,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-     * Contains an encrypted DataMessage
+     * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
      * 
*/ public boolean hasLegacyMessage() { @@ -499,7 +499,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-     * Contains an encrypted DataMessage
+     * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
      * 
*/ public com.google.protobuf.ByteString getLegacyMessage() { @@ -1119,7 +1119,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-       * Contains an encrypted DataMessage
+       * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
        * 
*/ public boolean hasLegacyMessage() { @@ -1129,7 +1129,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-       * Contains an encrypted DataMessage
+       * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
        * 
*/ public com.google.protobuf.ByteString getLegacyMessage() { @@ -1139,7 +1139,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-       * Contains an encrypted DataMessage
+       * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
        * 
*/ public Builder setLegacyMessage(com.google.protobuf.ByteString value) { @@ -1155,7 +1155,7 @@ public final class MessageProtos { * optional bytes legacyMessage = 6; * *
-       * Contains an encrypted DataMessage
+       * Contains an encrypted DataMessage XXX -- Remove after 10/01/15
        * 
*/ public Builder clearLegacyMessage() { diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java b/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java index 4c0d49f2f..83f01200f 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/ApnFallbackManager.java @@ -6,11 +6,16 @@ import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.RatioGauge; import com.codahale.metrics.SharedMetricRegistries; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.entities.ApnMessage; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import java.util.Iterator; @@ -21,7 +26,7 @@ import java.util.concurrent.TimeUnit; import static com.codahale.metrics.MetricRegistry.name; import io.dropwizard.lifecycle.Managed; -public class ApnFallbackManager implements Managed, Runnable { +public class ApnFallbackManager implements Managed, Runnable, DispatchChannel { private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class); @@ -35,21 +40,28 @@ public class ApnFallbackManager implements Managed, Runnable { } private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue(); - private final PushServiceClient pushServiceClient; - public ApnFallbackManager(PushServiceClient pushServiceClient) { + private final PushServiceClient pushServiceClient; + private final PubSubManager pubSubManager; + + public ApnFallbackManager(PushServiceClient pushServiceClient, PubSubManager pubSubManager) { this.pushServiceClient = pushServiceClient; + this.pubSubManager = pubSubManager; } public void schedule(final WebsocketAddress address, ApnFallbackTask task) { voipOneDelivery.mark(); - taskQueue.put(address, task); + + if (taskQueue.put(address, task)) { + pubSubManager.subscribe(new WebSocketConnectionInfo(address), this); + } } - public void cancel(WebsocketAddress address) { + private void cancel(WebsocketAddress address) { ApnFallbackTask task = taskQueue.remove(address); if (task != null) { + pubSubManager.unsubscribe(new WebSocketConnectionInfo(address), this); voipOneSuccess.mark(); voipOneSuccessHistogram.update(System.currentTimeMillis() - task.getScheduledTime()); } @@ -72,6 +84,7 @@ public class ApnFallbackManager implements Managed, Runnable { Entry taskEntry = taskQueue.get(); ApnFallbackTask task = taskEntry.getValue(); + pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this); pushServiceClient.send(new ApnMessage(task.getMessage(), task.getApnId(), false, ApnMessage.MAX_EXPIRATION)); } catch (Throwable e) { @@ -80,6 +93,31 @@ public class ApnFallbackManager implements Managed, Runnable { } } + @Override + public void onDispatchMessage(String channel, byte[] message) { + try { + PubSubMessage notification = PubSubMessage.parseFrom(message); + + if (notification.getType().getNumber() == PubSubMessage.Type.CONNECTED_VALUE) { + WebSocketConnectionInfo address = new WebSocketConnectionInfo(channel); + cancel(address.getWebsocketAddress()); + } else { + logger.warn("Got strange pubsub type: " + notification.getType().getNumber()); + } + + } catch (WebSocketConnectionInfo.FormattingException e) { + logger.warn("Bad formatting?", e); + } catch (InvalidProtocolBufferException e) { + logger.warn("Bad protobuf", e); + } + } + + @Override + public void onDispatchSubscribed(String channel) {} + + @Override + public void onDispatchUnsubscribed(String channel) {} + public static class ApnFallbackTask { private final long delay; @@ -147,10 +185,12 @@ public class ApnFallbackManager implements Managed, Runnable { } } - public void put(WebsocketAddress address, ApnFallbackTask task) { + public boolean put(WebsocketAddress address, ApnFallbackTask task) { synchronized (tasks) { - tasks.put(address, task); + ApnFallbackTask previous = tasks.put(address, task); tasks.notifyAll(); + + return previous == null; } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java new file mode 100644 index 000000000..55c1a05cf --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubAddress.java @@ -0,0 +1,5 @@ +package org.whispersystems.textsecuregcm.storage; + +public interface PubSubAddress { + public String serialize(); +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java index 7eef751e4..f8ba60d6a 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java @@ -1,13 +1,9 @@ package org.whispersystems.textsecuregcm.storage; -import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.dispatch.DispatchManager; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; - -import java.util.concurrent.atomic.AtomicInteger; import io.dropwizard.lifecycle.Managed; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; @@ -21,7 +17,7 @@ public class PubSubManager implements Managed { private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); private final DispatchManager dispatchManager; - private final JedisPool jedisPool; + private final JedisPool jedisPool; private boolean subscribed = false; @@ -49,21 +45,19 @@ public class PubSubManager implements Managed { dispatchManager.shutdown(); } - public void subscribe(WebsocketAddress address, DispatchChannel channel) { - String serializedAddress = address.serialize(); - dispatchManager.subscribe(serializedAddress, channel); + public void subscribe(PubSubAddress address, DispatchChannel channel) { + dispatchManager.subscribe(address.serialize(), channel); } - public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) { - String serializedAddress = address.serialize(); - dispatchManager.unsubscribe(serializedAddress, dispatchChannel); + public void unsubscribe(PubSubAddress address, DispatchChannel dispatchChannel) { + dispatchManager.unsubscribe(address.serialize(), dispatchChannel); } - public boolean hasLocalSubscription(WebsocketAddress address) { + public boolean hasLocalSubscription(PubSubAddress address) { return dispatchManager.hasSubscription(address.serialize()); } - public boolean publish(WebsocketAddress address, PubSubMessage message) { + public boolean publish(PubSubAddress address, PubSubMessage message) { return publish(address.serialize().getBytes(), message); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java index 7ca6225e3..6f9b91ac9 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java @@ -162,6 +162,10 @@ public final class PubSubProtos { * CLOSE = 4; */ CLOSE(4, 4), + /** + * CONNECTED = 5; + */ + CONNECTED(5, 5), ; /** @@ -184,6 +188,10 @@ public final class PubSubProtos { * CLOSE = 4; */ public static final int CLOSE_VALUE = 4; + /** + * CONNECTED = 5; + */ + public static final int CONNECTED_VALUE = 5; public final int getNumber() { return value; } @@ -195,6 +203,7 @@ public final class PubSubProtos { case 2: return DELIVER; case 3: return KEEPALIVE; case 4: return CLOSE; + case 5: return CONNECTED; default: return null; } } @@ -620,13 +629,13 @@ public final class PubSubProtos { descriptor; static { java.lang.String[] descriptorData = { - "\n\023PubSubMessage.proto\022\ntextsecure\"\230\001\n\rPu" + + "\n\023PubSubMessage.proto\022\ntextsecure\"\247\001\n\rPu" + "bSubMessage\022,\n\004type\030\001 \001(\0162\036.textsecure.P" + - "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"H\n\004Ty" + + "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"W\n\004Ty" + "pe\022\013\n\007UNKNOWN\020\000\022\014\n\010QUERY_DB\020\001\022\013\n\007DELIVER" + - "\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004B8\n(org.whis" + - "persystems.textsecuregcm.storageB\014PubSub" + - "Protos" + "\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004\022\r\n\tCONNECTE" + + "D\020\005B8\n(org.whispersystems.textsecuregcm." + + "storageB\014PubSubProtos" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java index 1fe77389b..441a70b64 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/util/Util.java +++ b/src/main/java/org/whispersystems/textsecuregcm/util/Util.java @@ -20,6 +20,7 @@ import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.Arrays; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -129,6 +130,10 @@ public class Util { } } + public static int hashCode(Object... objects) { + return Arrays.hashCode(objects); + } + public static long todayInMillis() { return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis())); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 032fb9d1d..99c5a34c0 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -13,6 +13,8 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.session.WebSocketSessionContext; @@ -47,15 +49,16 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final Account account = context.getAuthenticated(Account.class); - final Device device = account.getAuthenticatedDevice().get(); - final long connectTime = System.currentTimeMillis(); - final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, - messagesManager, account, device, - context.getClient()); + final Account account = context.getAuthenticated(Account.class); + final Device device = account.getAuthenticatedDevice().get(); + final long connectTime = System.currentTimeMillis(); + final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); + final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, + messagesManager, account, device, + context.getClient()); - apnFallbackManager.cancel(address); + pubSubManager.publish(info, PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED).build()); updateLastSeen(account, device); pubSubManager.subscribe(address, connection); diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java new file mode 100644 index 000000000..dcb8372d4 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionInfo.java @@ -0,0 +1,58 @@ +package org.whispersystems.textsecuregcm.websocket; + +import org.whispersystems.textsecuregcm.storage.PubSubAddress; +import org.whispersystems.textsecuregcm.util.Util; + +public class WebSocketConnectionInfo implements PubSubAddress { + + private final WebsocketAddress address; + + public WebSocketConnectionInfo(WebsocketAddress address) { + this.address = address; + } + + public WebSocketConnectionInfo(String serialized) throws FormattingException { + String[] parts = serialized.split("[:]", 3); + + if (parts.length != 3 || !"c".equals(parts[2])) { + throw new FormattingException("Bad address: " + serialized); + } + + try { + this.address = new WebsocketAddress(parts[0], Long.parseLong(parts[1])); + } catch (NumberFormatException e) { + throw new FormattingException(e); + } + } + + public String serialize() { + return address.serialize() + ":c"; + } + + public WebsocketAddress getWebsocketAddress() { + return address; + } + + @Override + public boolean equals(Object other) { + return + other != null && + other instanceof WebSocketConnectionInfo + && ((WebSocketConnectionInfo)other).address.equals(address); + } + + @Override + public int hashCode() { + return Util.hashCode(address, "c"); + } + + public static class FormattingException extends Exception { + public FormattingException(String message) { + super(message); + } + + public FormattingException(Exception e) { + super(e); + } + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java index 3ad5d944a..8009e9d0d 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java @@ -1,6 +1,8 @@ package org.whispersystems.textsecuregcm.websocket; -public class WebsocketAddress { +import org.whispersystems.textsecuregcm.storage.PubSubAddress; + +public class WebsocketAddress implements PubSubAddress { private final String number; private final long deviceId; diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java index 67e611b91..bfd47aca7 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/push/ApnFallbackManagerTest.java @@ -6,7 +6,10 @@ import org.whispersystems.textsecuregcm.entities.ApnMessage; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask; import org.whispersystems.textsecuregcm.push.PushServiceClient; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import static org.junit.Assert.assertEquals; @@ -18,11 +21,13 @@ public class ApnFallbackManagerTest { @Test public void testFullFallback() throws Exception { PushServiceClient pushServiceClient = mock(PushServiceClient.class); - WebsocketAddress address = mock(WebsocketAddress.class ); + PubSubManager pubSubManager = mock(PubSubManager.class); + WebsocketAddress address = new WebsocketAddress("+14152222223", 1L); + WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111); ApnFallbackTask task = new ApnFallbackTask("foo", message, 500); - ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient); + ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager); apnFallbackManager.start(); apnFallbackManager.schedule(address, task); @@ -31,6 +36,7 @@ public class ApnFallbackManagerTest { ArgumentCaptor captor = ArgumentCaptor.forClass(ApnMessage.class); verify(pushServiceClient, times(1)).send(captor.capture()); + verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager)); assertEquals(captor.getValue().getMessage(), message.getMessage()); assertEquals(captor.getValue().getApnId(), task.getApnId()); @@ -41,15 +47,22 @@ public class ApnFallbackManagerTest { @Test public void testNoFallback() throws Exception { PushServiceClient pushServiceClient = mock(PushServiceClient.class); - WebsocketAddress address = mock(WebsocketAddress.class ); + PubSubManager pubSubManager = mock(PubSubManager.class); + WebsocketAddress address = new WebsocketAddress("+14152222222", 1); + WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 5555); ApnFallbackTask task = new ApnFallbackTask ("foo", message, 500); - ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient); + ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager); apnFallbackManager.start(); apnFallbackManager.schedule(address, task); - apnFallbackManager.cancel(address); + apnFallbackManager.onDispatchMessage(info.serialize(), + PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.CONNECTED) + .build().toByteArray()); + + verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager)); Util.sleep(1100);