Disconnect sockets on other servers when new websocket comes in

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2017-08-18 15:58:12 -07:00
parent 322548f078
commit 789f11a5c4
4 changed files with 30 additions and 14 deletions

View File

@ -96,7 +96,7 @@
<dependency>
<groupId>org.whispersystems</groupId>
<artifactId>websocket-resources</artifactId>
<version>0.5.2</version>
<version>0.5.3</version>
</dependency>
<dependency>
<groupId>org.whispersystems</groupId>

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
@ -18,6 +19,8 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.security.SecureRandom;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
@ -45,16 +48,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final Timer.Context timer = durationTimer.time();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient());
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final String connectionId = String.valueOf(new SecureRandom().nextLong());
final Timer.Context timer = durationTimer.time();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient(), connectionId);
final PubSubMessage connectMessage = PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED)
.setContent(ByteString.copyFrom(connectionId.getBytes()))
.build();
pubSubManager.publish(info, PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED).build());
pubSubManager.publish(info, connectMessage);
pubSubManager.publish(address, connectMessage);
pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {

View File

@ -53,13 +53,15 @@ public class WebSocketConnection implements DispatchChannel {
private final Account account;
private final Device device;
private final WebSocketClient client;
private final String connectionId;
public WebSocketConnection(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
Account account,
Device device,
WebSocketClient client)
WebSocketClient client,
String connectionId)
{
this.pushSender = pushSender;
this.receiptSender = receiptSender;
@ -67,6 +69,7 @@ public class WebSocketConnection implements DispatchChannel {
this.account = account;
this.device = device;
this.client = client;
this.connectionId = connectionId;
}
@Override
@ -81,6 +84,11 @@ public class WebSocketConnection implements DispatchChannel {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.<Long>absent(), false);
break;
case PubSubMessage.Type.CONNECTED_VALUE:
if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) {
client.hardDisconnectQuietly();
}
break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}

View File

@ -143,7 +143,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client);
account, device, client, "someid");
connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
@ -227,7 +227,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client);
account, device, client, "anotherid");
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
@ -333,7 +333,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client);
account, device, client, "onemoreid");
connection.onDispatchSubscribed(websocketAddress.serialize());