diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 06c1c4aef..26cb56838 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -9,6 +9,8 @@ import com.codahale.metrics.Counter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.MessageSender; @@ -30,6 +32,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final Timer unauthenticatedDurationTimer = metricRegistry.timer(name(WebSocketConnection.class, "unauthenticated_connection_duration")); private static final Counter openWebsocketCounter = metricRegistry.counter(name(WebSocketConnection.class, "open_websockets")); + private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class); + private final ReceiptSender receiptSender; private final MessagesManager messagesManager; private final MessageSender messageSender; @@ -61,25 +65,31 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.inc(); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); - clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection); - messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); - connection.start(); - context.addListener(new WebSocketSessionContext.WebSocketEventListener() { @Override public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { - clientPresenceManager.clearPresence(account.getUuid(), device.getId()); - messagesManager.removeMessageAvailabilityListener(connection); - connection.stop(); - openWebsocketCounter.dec(); timer.stop(); - if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) { - messageSender.sendNewMessageNotification(account, device); - } + RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(account.getUuid(), device.getId())); + RedisOperation.unchecked(() -> { + messagesManager.removeMessageAvailabilityListener(connection); + + if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) { + messageSender.sendNewMessageNotification(account, device); + } + }); } }); + + try { + clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection); + messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); + connection.start(); + } catch (final Exception e) { + log.warn("Failed to initialize websocket", e); + context.getClient().close(1011, "Unexpected error initializing connection"); + } } else { final Timer.Context timer = unauthenticatedDurationTimer.time(); context.addListener((context1, statusCode, reason) -> timer.stop());