Send push notifications if websockets close before all messages are delivered

This commit is contained in:
Jon Chambers 2020-10-27 16:02:55 -04:00 committed by GitHub
parent ae566dca98
commit 05d9ec673e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 48 additions and 9 deletions

View File

@ -423,7 +423,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, apnFallbackManager, clientPresenceManager));
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
webSocketEnvironment.jersey().register(messageController);

View File

@ -138,21 +138,24 @@ public class MessageSender implements Managed {
throw new AssertionError();
}
final boolean clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());
final boolean clientPresent;
if (online) {
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());
if (clientPresent) {
messagesManager.insertEphemeral(account.getUuid(), device.getId(), message);
}
} else {
messagesManager.insert(account.getUuid(), device.getId(), message);
// We check for client presence after inserting the message to take a conservative view of notifications. If the
// client wasn't present at the time of insertion but is now, they'll retrieve the message. If they were present
// but disconnected before the message was delivered, we should send a notification.
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());
if (!clientPresent) {
if (!Util.isEmpty(device.getGcmId())) {
sendGcmNotification(account, device);
} else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) {
sendApnNotification(account, device);
}
sendNewMessageNotification(account, device);
}
}
@ -164,6 +167,14 @@ public class MessageSender implements Managed {
Metrics.counter(SEND_COUNTER_NAME, tags).increment();
}
public void sendNewMessageNotification(final Account account, final Device device) {
if (!Util.isEmpty(device.getGcmId())) {
sendGcmNotification(account, device);
} else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) {
sendApnNotification(account, device);
}
}
private void sendGcmNotification(Account account, Device device) {
GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(),
(int)device.getId(), GcmMessage.Type.NOTIFICATION, Optional.empty());

View File

@ -41,6 +41,7 @@ public class Messages {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer storeTimer = metricRegistry.timer(name(Messages.class, "store" ));
private final Timer loadTimer = metricRegistry.timer(name(Messages.class, "load" ));
private final Timer hasMessagesTimer = metricRegistry.timer(name(Messages.class, "hasMessages" ));
private final Timer removeBySourceTimer = metricRegistry.timer(name(Messages.class, "removeBySource"));
private final Timer removeByGuidTimer = metricRegistry.timer(name(Messages.class, "removeByGuid" ));
private final Timer removeByIdTimer = metricRegistry.timer(name(Messages.class, "removeById" ));

View File

@ -220,6 +220,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return removedMessages;
}
public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) {
return redisCluster.withBinaryCluster(connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0);
}
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> {

View File

@ -51,6 +51,10 @@ public class MessagesManager {
return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice);
}
public boolean hasCachedMessages(final UUID destinationUuid, final long destinationDevice) {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));

View File

@ -6,6 +6,7 @@ import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
@ -26,16 +27,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final MessageSender messageSender;
private final ApnFallbackManager apnFallbackManager;
private final ClientPresenceManager clientPresenceManager;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager,
ApnFallbackManager apnFallbackManager,
final MessageSender messageSender, ApnFallbackManager apnFallbackManager,
ClientPresenceManager clientPresenceManager)
{
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageSender = messageSender;
this.apnFallbackManager = apnFallbackManager;
this.clientPresenceManager = clientPresenceManager;
}
@ -66,6 +69,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
openWebsocketCounter.dec();
timer.stop();
if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) {
messageSender.sendNewMessageNotification(account, device);
}
}
});
} else {

View File

@ -160,6 +160,17 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@Test
public void testHasMessages() {
assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
}
@Test
@Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) {

View File

@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -72,7 +73,7 @@ public class WebSocketConnectionTest {
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, apnFallbackManager, mock(ClientPresenceManager.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))