diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 80b7ec758..658ea79af 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -12,7 +12,6 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.ScoredValue; import io.lettuce.core.ScriptOutputType; -import io.lettuce.core.SetArgs; import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 10f943af2..0c915810e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -23,12 +23,10 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; -import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; -import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; @@ -139,10 +137,12 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { if (authenticated) { final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class); - final Device device = auth.getAuthenticatedDevice(); final Timer.Sample sample = Timer.start(); final WebSocketConnection connection = new WebSocketConnection(receiptSender, - messagesManager, messageMetrics, auth, device, + messagesManager, + messageMetrics, + pushNotificationManager, + auth, context.getClient(), scheduledExecutorService, messageDeliveryScheduler, @@ -150,8 +150,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketAtomicInteger.incrementAndGet(); - pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, userAgent); - final AtomicReference> renewPresenceFutureReference = new AtomicReference<>(); context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { @@ -164,29 +162,41 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { renewPresenceFuture.cancel(false); } + // We begin the shutdown process by removing this client's "presence," which means it will again begin to + // receive push notifications for inbound messages. We should do this first because, at this point, the + // connection has already closed and attempts to actually deliver a message via the connection will not succeed. + // It's preferable to start sending push notifications as soon as possible. + RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection)); + + // Next, we stop listening for inbound messages. If a message arrives after this call, the websocket connection + // will not be notified and will not change its state, but that's okay because it has already closed and + // attempts to deliver mesages via this connection will not succeed. + RedisOperation.unchecked(() -> messagesManager.removeMessageAvailabilityListener(connection)); + + // Finally, stop trying to deliver messages and send a push notification if the connection is aware of any + // undelivered messages. connection.stop(); - - RedisOperation.unchecked( - () -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId(), connection)); - RedisOperation.unchecked(() -> { - messagesManager.removeMessageAvailabilityListener(connection); - - if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) { - try { - pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true); - } catch (NotPushRegisteredException ignored) { - } - } - }); }); try { + // Once we add this connection as a message availability listener, it will be notified any time a new message + // arrives in the message cache. This updates the connection's "may have messages" state. It's important that + // we do this first because we want to make sure we're accurately tracking message availability in the + // connection's internal state. + messagesManager.addMessageAvailabilityListener(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection); + + // Once we "start" the websocket connection, we'll cancel any scheduled "you may have new messages" push + // notifications and begin delivering any stored messages for the connected device. We have not yet declared the + // client as "present" yet. If a message arrives at this point, we will update the message availability state + // correctly, but we may also send a spurious push notification. connection.start(); - clientPresenceManager.setPresent(auth.getAccount().getUuid(), device.getId(), connection); - messagesManager.addMessageAvailabilityListener(auth.getAccount().getUuid(), device.getId(), connection); + + // Finally, we register this client's presence, which suppresses push notifications. We do this last because + // receiving extra push notifications is generally preferable to missing out on a push notification. + clientPresenceManager.setPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection); renewPresenceFutureReference.set(scheduledExecutorService.scheduleAtFixedRate(() -> RedisOperation.unchecked(() -> - clientPresenceManager.renewPresence(auth.getAccount().getUuid(), device.getId())), + clientPresenceManager.renewPresence(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())), RENEW_PRESENCE_INTERVAL_MINUTES, RENEW_PRESENCE_INTERVAL_MINUTES, TimeUnit.MINUTES)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index bbc552cb8..52fb97934 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -43,6 +43,8 @@ import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -86,6 +88,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac "sendMessages"); private static final String SEND_MESSAGE_ERROR_COUNTER = MetricsUtil.name(WebSocketConnection.class, "sendMessageError"); + private static final String PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME = + MetricsUtil.name(WebSocketConnection.class, "pushNotificationOnClose"); private static final String STATUS_CODE_TAG = "status"; private static final String STATUS_MESSAGE_TAG = "message"; private static final String ERROR_TYPE_TAG = "errorType"; @@ -110,9 +114,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final ReceiptSender receiptSender; private final MessagesManager messagesManager; private final MessageMetrics messageMetrics; + private final PushNotificationManager pushNotificationManager; private final AuthenticatedAccount auth; - private final Device device; private final WebSocketClient client; private final int sendFuturesTimeoutMillis; @@ -143,8 +147,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac public WebSocketConnection(ReceiptSender receiptSender, MessagesManager messagesManager, MessageMetrics messageMetrics, + PushNotificationManager pushNotificationManager, AuthenticatedAccount auth, - Device device, WebSocketClient client, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, @@ -153,8 +157,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this(receiptSender, messagesManager, messageMetrics, + pushNotificationManager, auth, - device, client, DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, scheduledExecutorService, @@ -166,8 +170,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac WebSocketConnection(ReceiptSender receiptSender, MessagesManager messagesManager, MessageMetrics messageMetrics, + PushNotificationManager pushNotificationManager, AuthenticatedAccount auth, - Device device, WebSocketClient client, int sendFuturesTimeoutMillis, ScheduledExecutorService scheduledExecutorService, @@ -177,8 +181,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; + this.pushNotificationManager = pushNotificationManager; this.auth = auth; - this.device = device; this.client = client; this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis; this.scheduledExecutorService = scheduledExecutorService; @@ -187,6 +191,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } public void start() { + pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), client.getUserAgent()); queueDrainStartTime.set(System.currentTimeMillis()); processStoredMessages(); } @@ -204,6 +209,17 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } client.close(1000, "OK"); + + if (storedMessageState.get() != StoredMessageState.EMPTY) { + try { + pushNotificationManager.sendNewMessageNotification(auth.getAccount(), auth.getAuthenticatedDevice().getId(), true); + + Metrics.counter(PUSH_NOTIFICATION_ON_CLOSE_COUNTER_NAME, + Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()))) + .increment(); + } catch (NotPushRegisteredException ignored) { + } + } } private CompletableFuture sendMessage(final Envelope message, StoredMessageInfo storedMessageInfo) { @@ -224,7 +240,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } else { messageMetrics.measureOutgoingMessageLatency(message.getServerTimestamp(), "websocket", - device.isPrimary(), + auth.getAuthenticatedDevice().isPrimary(), client.getUserAgent(), clientReleaseManager); } @@ -232,12 +248,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final CompletableFuture result; if (isSuccessResponse(response)) { - result = messagesManager.delete(auth.getAccount().getUuid(), device, + result = messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) .thenApply(ignored -> null); if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { - recordMessageDeliveryDuration(message.getServerTimestamp(), device); + recordMessageDeliveryDuration(message.getServerTimestamp(), auth.getAuthenticatedDevice()); sendDeliveryReceiptFor(message); } } else { @@ -359,7 +375,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private void sendMessages(final boolean cachedMessagesOnly, final CompletableFuture queueCleared) { final Publisher messages = - messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), device, cachedMessagesOnly); + messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly); final AtomicBoolean hasErrored = new AtomicBoolean(); @@ -418,7 +434,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); if (envelope.getStory() && !client.shouldDeliverStories()) { - messagesManager.delete(auth.getAccount().getUuid(), device, messageGuid, envelope.getServerTimestamp()); + messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), messageGuid, envelope.getServerTimestamp()); return CompletableFuture.completedFuture(null); } else { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 4299bf8d4..6a4c3ccd9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -47,12 +47,12 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; +import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.MessagesCache; @@ -126,8 +126,8 @@ class WebSocketConnectionIntegrationTest { mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), + mock(PushNotificationManager.class), new AuthenticatedAccount(account, device), - device, webSocketClient, scheduledExecutorService, messageDeliveryScheduler, @@ -212,8 +212,8 @@ class WebSocketConnectionIntegrationTest { mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), + mock(PushNotificationManager.class), new AuthenticatedAccount(account, device), - device, webSocketClient, scheduledExecutorService, messageDeliveryScheduler, @@ -279,8 +279,8 @@ class WebSocketConnectionIntegrationTest { mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessageMetrics(), + mock(PushNotificationManager.class), new AuthenticatedAccount(account, device), - device, webSocketClient, 100, // use a very short timeout, so that this test completes quickly scheduledExecutorService, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 4800d17be..633da1715 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -626,7 +626,8 @@ class WebSocketConnectionTest { } private @NotNull WebSocketConnection webSocketConnection(final WebSocketClient client) { - return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), auth, device, client, + return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), + mock(PushNotificationManager.class), auth, client, retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager); }