Make APN fallback behave well in multi-server environments.

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2015-07-29 15:02:44 -07:00
parent 8d0d934249
commit d4e618893c
13 changed files with 179 additions and 49 deletions

View File

@ -26,6 +26,7 @@ message PubSubMessage {
DELIVER = 2;
KEEPALIVE = 3;
CLOSE = 4;
CONNECTED = 5;
}
optional Type type = 1;

View File

@ -78,7 +78,7 @@ public class DispatchManager extends Thread {
public boolean hasSubscription(String name) {
return subscriptions.containsKey(name);
}
@Override
public void run() {
while (running) {

View File

@ -176,7 +176,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FederatedPeerAuthenticator federatedPeerAuthenticator = new FederatedPeerAuthenticator(config.getFederationConfiguration());
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), cacheClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());

View File

@ -76,7 +76,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
boolean hasLegacyMessage();
@ -84,7 +84,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
com.google.protobuf.ByteString getLegacyMessage();
@ -489,7 +489,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public boolean hasLegacyMessage() {
@ -499,7 +499,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public com.google.protobuf.ByteString getLegacyMessage() {
@ -1119,7 +1119,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public boolean hasLegacyMessage() {
@ -1129,7 +1129,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public com.google.protobuf.ByteString getLegacyMessage() {
@ -1139,7 +1139,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public Builder setLegacyMessage(com.google.protobuf.ByteString value) {
@ -1155,7 +1155,7 @@ public final class MessageProtos {
* <code>optional bytes legacyMessage = 6;</code>
*
* <pre>
* Contains an encrypted DataMessage
* Contains an encrypted DataMessage XXX -- Remove after 10/01/15
* </pre>
*/
public Builder clearLegacyMessage() {

View File

@ -6,11 +6,16 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.RatioGauge;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Iterator;
@ -21,7 +26,7 @@ import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
public class ApnFallbackManager implements Managed, Runnable {
public class ApnFallbackManager implements Managed, Runnable, DispatchChannel {
private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class);
@ -35,21 +40,28 @@ public class ApnFallbackManager implements Managed, Runnable {
}
private final ApnFallbackTaskQueue taskQueue = new ApnFallbackTaskQueue();
private final PushServiceClient pushServiceClient;
public ApnFallbackManager(PushServiceClient pushServiceClient) {
private final PushServiceClient pushServiceClient;
private final PubSubManager pubSubManager;
public ApnFallbackManager(PushServiceClient pushServiceClient, PubSubManager pubSubManager) {
this.pushServiceClient = pushServiceClient;
this.pubSubManager = pubSubManager;
}
public void schedule(final WebsocketAddress address, ApnFallbackTask task) {
voipOneDelivery.mark();
taskQueue.put(address, task);
if (taskQueue.put(address, task)) {
pubSubManager.subscribe(new WebSocketConnectionInfo(address), this);
}
}
public void cancel(WebsocketAddress address) {
private void cancel(WebsocketAddress address) {
ApnFallbackTask task = taskQueue.remove(address);
if (task != null) {
pubSubManager.unsubscribe(new WebSocketConnectionInfo(address), this);
voipOneSuccess.mark();
voipOneSuccessHistogram.update(System.currentTimeMillis() - task.getScheduledTime());
}
@ -72,6 +84,7 @@ public class ApnFallbackManager implements Managed, Runnable {
Entry<WebsocketAddress, ApnFallbackTask> taskEntry = taskQueue.get();
ApnFallbackTask task = taskEntry.getValue();
pubSubManager.unsubscribe(new WebSocketConnectionInfo(taskEntry.getKey()), this);
pushServiceClient.send(new ApnMessage(task.getMessage(), task.getApnId(),
false, ApnMessage.MAX_EXPIRATION));
} catch (Throwable e) {
@ -80,6 +93,31 @@ public class ApnFallbackManager implements Managed, Runnable {
}
}
@Override
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage notification = PubSubMessage.parseFrom(message);
if (notification.getType().getNumber() == PubSubMessage.Type.CONNECTED_VALUE) {
WebSocketConnectionInfo address = new WebSocketConnectionInfo(channel);
cancel(address.getWebsocketAddress());
} else {
logger.warn("Got strange pubsub type: " + notification.getType().getNumber());
}
} catch (WebSocketConnectionInfo.FormattingException e) {
logger.warn("Bad formatting?", e);
} catch (InvalidProtocolBufferException e) {
logger.warn("Bad protobuf", e);
}
}
@Override
public void onDispatchSubscribed(String channel) {}
@Override
public void onDispatchUnsubscribed(String channel) {}
public static class ApnFallbackTask {
private final long delay;
@ -147,10 +185,12 @@ public class ApnFallbackManager implements Managed, Runnable {
}
}
public void put(WebsocketAddress address, ApnFallbackTask task) {
public boolean put(WebsocketAddress address, ApnFallbackTask task) {
synchronized (tasks) {
tasks.put(address, task);
ApnFallbackTask previous = tasks.put(address, task);
tasks.notifyAll();
return previous == null;
}
}

View File

@ -0,0 +1,5 @@
package org.whispersystems.textsecuregcm.storage;
public interface PubSubAddress {
public String serialize();
}

View File

@ -1,13 +1,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.concurrent.atomic.AtomicInteger;
import io.dropwizard.lifecycle.Managed;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
@ -21,7 +17,7 @@ public class PubSubManager implements Managed {
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final DispatchManager dispatchManager;
private final JedisPool jedisPool;
private final JedisPool jedisPool;
private boolean subscribed = false;
@ -49,21 +45,19 @@ public class PubSubManager implements Managed {
dispatchManager.shutdown();
}
public void subscribe(WebsocketAddress address, DispatchChannel channel) {
String serializedAddress = address.serialize();
dispatchManager.subscribe(serializedAddress, channel);
public void subscribe(PubSubAddress address, DispatchChannel channel) {
dispatchManager.subscribe(address.serialize(), channel);
}
public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) {
String serializedAddress = address.serialize();
dispatchManager.unsubscribe(serializedAddress, dispatchChannel);
public void unsubscribe(PubSubAddress address, DispatchChannel dispatchChannel) {
dispatchManager.unsubscribe(address.serialize(), dispatchChannel);
}
public boolean hasLocalSubscription(WebsocketAddress address) {
public boolean hasLocalSubscription(PubSubAddress address) {
return dispatchManager.hasSubscription(address.serialize());
}
public boolean publish(WebsocketAddress address, PubSubMessage message) {
public boolean publish(PubSubAddress address, PubSubMessage message) {
return publish(address.serialize().getBytes(), message);
}

View File

@ -162,6 +162,10 @@ public final class PubSubProtos {
* <code>CLOSE = 4;</code>
*/
CLOSE(4, 4),
/**
* <code>CONNECTED = 5;</code>
*/
CONNECTED(5, 5),
;
/**
@ -184,6 +188,10 @@ public final class PubSubProtos {
* <code>CLOSE = 4;</code>
*/
public static final int CLOSE_VALUE = 4;
/**
* <code>CONNECTED = 5;</code>
*/
public static final int CONNECTED_VALUE = 5;
public final int getNumber() { return value; }
@ -195,6 +203,7 @@ public final class PubSubProtos {
case 2: return DELIVER;
case 3: return KEEPALIVE;
case 4: return CLOSE;
case 5: return CONNECTED;
default: return null;
}
}
@ -620,13 +629,13 @@ public final class PubSubProtos {
descriptor;
static {
java.lang.String[] descriptorData = {
"\n\023PubSubMessage.proto\022\ntextsecure\"\230\001\n\rPu" +
"\n\023PubSubMessage.proto\022\ntextsecure\"\247\001\n\rPu" +
"bSubMessage\022,\n\004type\030\001 \001(\0162\036.textsecure.P" +
"ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"H\n\004Ty" +
"ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"W\n\004Ty" +
"pe\022\013\n\007UNKNOWN\020\000\022\014\n\010QUERY_DB\020\001\022\013\n\007DELIVER" +
"\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004B8\n(org.whis" +
"persystems.textsecuregcm.storageB\014PubSub" +
"Protos"
"\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004\022\r\n\tCONNECTE" +
"D\020\005B8\n(org.whispersystems.textsecuregcm." +
"storageB\014PubSubProtos"
};
com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner =
new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() {

View File

@ -20,6 +20,7 @@ import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.TimeUnit;
@ -129,6 +130,10 @@ public class Util {
}
}
public static int hashCode(Object... objects) {
return Arrays.hashCode(objects);
}
public static long todayInMillis() {
return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(System.currentTimeMillis()));
}

View File

@ -13,6 +13,8 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
@ -47,15 +49,16 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final long connectTime = System.currentTimeMillis();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient());
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final long connectTime = System.currentTimeMillis();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient());
apnFallbackManager.cancel(address);
pubSubManager.publish(info, PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED).build());
updateLastSeen(account, device);
pubSubManager.subscribe(address, connection);

View File

@ -0,0 +1,58 @@
package org.whispersystems.textsecuregcm.websocket;
import org.whispersystems.textsecuregcm.storage.PubSubAddress;
import org.whispersystems.textsecuregcm.util.Util;
public class WebSocketConnectionInfo implements PubSubAddress {
private final WebsocketAddress address;
public WebSocketConnectionInfo(WebsocketAddress address) {
this.address = address;
}
public WebSocketConnectionInfo(String serialized) throws FormattingException {
String[] parts = serialized.split("[:]", 3);
if (parts.length != 3 || !"c".equals(parts[2])) {
throw new FormattingException("Bad address: " + serialized);
}
try {
this.address = new WebsocketAddress(parts[0], Long.parseLong(parts[1]));
} catch (NumberFormatException e) {
throw new FormattingException(e);
}
}
public String serialize() {
return address.serialize() + ":c";
}
public WebsocketAddress getWebsocketAddress() {
return address;
}
@Override
public boolean equals(Object other) {
return
other != null &&
other instanceof WebSocketConnectionInfo
&& ((WebSocketConnectionInfo)other).address.equals(address);
}
@Override
public int hashCode() {
return Util.hashCode(address, "c");
}
public static class FormattingException extends Exception {
public FormattingException(String message) {
super(message);
}
public FormattingException(Exception e) {
super(e);
}
}
}

View File

@ -1,6 +1,8 @@
package org.whispersystems.textsecuregcm.websocket;
public class WebsocketAddress {
import org.whispersystems.textsecuregcm.storage.PubSubAddress;
public class WebsocketAddress implements PubSubAddress {
private final String number;
private final long deviceId;

View File

@ -6,7 +6,10 @@ import org.whispersystems.textsecuregcm.entities.ApnMessage;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager.ApnFallbackTask;
import org.whispersystems.textsecuregcm.push.PushServiceClient;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnectionInfo;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import static org.junit.Assert.assertEquals;
@ -18,11 +21,13 @@ public class ApnFallbackManagerTest {
@Test
public void testFullFallback() throws Exception {
PushServiceClient pushServiceClient = mock(PushServiceClient.class);
WebsocketAddress address = mock(WebsocketAddress.class );
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222223", 1L);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 1111);
ApnFallbackTask task = new ApnFallbackTask("foo", message, 500);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
apnFallbackManager.start();
apnFallbackManager.schedule(address, task);
@ -31,6 +36,7 @@ public class ApnFallbackManagerTest {
ArgumentCaptor<ApnMessage> captor = ArgumentCaptor.forClass(ApnMessage.class);
verify(pushServiceClient, times(1)).send(captor.capture());
verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager));
assertEquals(captor.getValue().getMessage(), message.getMessage());
assertEquals(captor.getValue().getApnId(), task.getApnId());
@ -41,15 +47,22 @@ public class ApnFallbackManagerTest {
@Test
public void testNoFallback() throws Exception {
PushServiceClient pushServiceClient = mock(PushServiceClient.class);
WebsocketAddress address = mock(WebsocketAddress.class );
PubSubManager pubSubManager = mock(PubSubManager.class);
WebsocketAddress address = new WebsocketAddress("+14152222222", 1);
WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
ApnMessage message = new ApnMessage("bar", "123", 1, "hmm", true, 5555);
ApnFallbackTask task = new ApnFallbackTask ("foo", message, 500);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushServiceClient, pubSubManager);
apnFallbackManager.start();
apnFallbackManager.schedule(address, task);
apnFallbackManager.cancel(address);
apnFallbackManager.onDispatchMessage(info.serialize(),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.CONNECTED)
.build().toByteArray());
verify(pubSubManager).unsubscribe(eq(info), eq(apnFallbackManager));
Util.sleep(1100);