Send push notifications if websockets close before all messages are delivered

This commit is contained in:
Jon Chambers 2020-10-27 16:02:55 -04:00 committed by GitHub
parent ae566dca98
commit 05d9ec673e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 48 additions and 9 deletions

View File

@ -423,7 +423,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
/// ///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, apnFallbackManager, clientPresenceManager)); webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
webSocketEnvironment.jersey().register(messageController); webSocketEnvironment.jersey().register(messageController);

View File

@ -138,21 +138,24 @@ public class MessageSender implements Managed {
throw new AssertionError(); throw new AssertionError();
} }
final boolean clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId()); final boolean clientPresent;
if (online) { if (online) {
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());
if (clientPresent) { if (clientPresent) {
messagesManager.insertEphemeral(account.getUuid(), device.getId(), message); messagesManager.insertEphemeral(account.getUuid(), device.getId(), message);
} }
} else { } else {
messagesManager.insert(account.getUuid(), device.getId(), message); messagesManager.insert(account.getUuid(), device.getId(), message);
// We check for client presence after inserting the message to take a conservative view of notifications. If the
// client wasn't present at the time of insertion but is now, they'll retrieve the message. If they were present
// but disconnected before the message was delivered, we should send a notification.
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());
if (!clientPresent) { if (!clientPresent) {
if (!Util.isEmpty(device.getGcmId())) { sendNewMessageNotification(account, device);
sendGcmNotification(account, device);
} else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) {
sendApnNotification(account, device);
}
} }
} }
@ -164,6 +167,14 @@ public class MessageSender implements Managed {
Metrics.counter(SEND_COUNTER_NAME, tags).increment(); Metrics.counter(SEND_COUNTER_NAME, tags).increment();
} }
public void sendNewMessageNotification(final Account account, final Device device) {
if (!Util.isEmpty(device.getGcmId())) {
sendGcmNotification(account, device);
} else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) {
sendApnNotification(account, device);
}
}
private void sendGcmNotification(Account account, Device device) { private void sendGcmNotification(Account account, Device device) {
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(), GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), GcmMessage.Type.NOTIFICATION, Optional.empty()); (int)device.getId(), GcmMessage.Type.NOTIFICATION, Optional.empty());

View File

@ -41,6 +41,7 @@ public class Messages {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer storeTimer = metricRegistry.timer(name(Messages.class, "store" )); private final Timer storeTimer = metricRegistry.timer(name(Messages.class, "store" ));
private final Timer loadTimer = metricRegistry.timer(name(Messages.class, "load" )); private final Timer loadTimer = metricRegistry.timer(name(Messages.class, "load" ));
private final Timer hasMessagesTimer = metricRegistry.timer(name(Messages.class, "hasMessages" ));
private final Timer removeBySourceTimer = metricRegistry.timer(name(Messages.class, "removeBySource")); private final Timer removeBySourceTimer = metricRegistry.timer(name(Messages.class, "removeBySource"));
private final Timer removeByGuidTimer = metricRegistry.timer(name(Messages.class, "removeByGuid" )); private final Timer removeByGuidTimer = metricRegistry.timer(name(Messages.class, "removeByGuid" ));
private final Timer removeByIdTimer = metricRegistry.timer(name(Messages.class, "removeById" )); private final Timer removeByIdTimer = metricRegistry.timer(name(Messages.class, "removeById" ));

View File

@ -220,6 +220,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return removedMessages; return removedMessages;
} }
public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) {
return redisCluster.withBinaryCluster(connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0);
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(final UUID destinationUuid, final long destinationDevice, final int limit) { public List<OutgoingMessageEntity> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> { return getMessagesTimer.record(() -> {

View File

@ -51,6 +51,10 @@ public class MessagesManager {
return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice); return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice);
} }
public boolean hasCachedMessages(final UUID destinationUuid, final long destinationDevice) {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));

View File

@ -6,6 +6,7 @@ import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -26,16 +27,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final ReceiptSender receiptSender; private final ReceiptSender receiptSender;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final MessageSender messageSender;
private final ApnFallbackManager apnFallbackManager; private final ApnFallbackManager apnFallbackManager;
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
public AuthenticatedConnectListener(ReceiptSender receiptSender, public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager, MessagesManager messagesManager,
ApnFallbackManager apnFallbackManager, final MessageSender messageSender, ApnFallbackManager apnFallbackManager,
ClientPresenceManager clientPresenceManager) ClientPresenceManager clientPresenceManager)
{ {
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.messageSender = messageSender;
this.apnFallbackManager = apnFallbackManager; this.apnFallbackManager = apnFallbackManager;
this.clientPresenceManager = clientPresenceManager; this.clientPresenceManager = clientPresenceManager;
} }
@ -66,6 +69,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
openWebsocketCounter.dec(); openWebsocketCounter.dec();
timer.stop(); timer.stop();
if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) {
messageSender.sendNewMessageNotification(account, device);
}
} }
}); });
} else { } else {

View File

@ -160,6 +160,17 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
} }
@Test
public void testHasMessages() {
assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
}
@Test @Test
@Parameters({"true", "false"}) @Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) { public void testGetMessages(final boolean sealedSender) {

View File

@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -72,7 +73,7 @@ public class WebSocketConnectionTest {
public void testCredentials() throws Exception { public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class); MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, apnFallbackManager, mock(ClientPresenceManager.class)); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))