Discard oversized messages bound for desktop clients via websockets.

This commit is contained in:
Jon Chambers 2020-12-07 11:38:56 -05:00 committed by Jon Chambers
parent 3a268aef50
commit 92fde83b3a
2 changed files with 175 additions and 1 deletions

View File

@ -29,6 +29,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.WebSocketClient;
@ -58,10 +59,14 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
private static final Meter clientNonSuccessResponseMeter = metricRegistry.meter(name(WebSocketConnection.class, "clientNonSuccessResponse"));
private static final Meter discardedMessagesMeter = metricRegistry.meter(name(WebSocketConnection.class, "discardedMessages"));
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
private static final String DISPLACEMENT_PLATFORM_TAG_NAME = "platform";
@VisibleForTesting
static final int MAX_DESKTOP_MESSAGE_SIZE = 1024 * 1024;
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final ReceiptSender receiptSender;
@ -71,6 +76,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private final Device device;
private final WebSocketClient client;
private final boolean isDesktopClient;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
@ -92,6 +99,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
this.account = account;
this.device = device;
this.client = client;
Optional<ClientPlatform> maybePlatform;
try {
maybePlatform = Optional.of(UserAgentUtil.parseUserAgentString(client.getUserAgent()).getPlatform());
} catch (final UnrecognizedUserAgentException e) {
maybePlatform = Optional.empty();
}
this.isDesktopClient = maybePlatform.map(platform -> platform == ClientPlatform.DESKTOP).orElse(false);
}
public void start() {
@ -211,7 +228,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
builder.setRelay(message.getRelay());
}
sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())));
final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {
messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), message.getId(), message.isCached());
discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null);
} else {
sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())));
}
}
CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Before;
import org.junit.Test;
@ -643,6 +644,153 @@ public class WebSocketConnectionTest {
verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false);
}
@Test
public void testDiscardOversizedMessagesForDesktop() {
MessagesManager storedMessages = mock(MessagesManager.class);
UUID accountUuid = UUID.randomUUID();
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage(1L, false, "sender1", senderOneUuid, 1111, false, "first"));
add(createMessage(2L, false, "sender1", senderOneUuid, 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)));
add(createMessage(3L, false, "sender2", senderTwoUuid, 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
when(device.getId()).thenReturn(2L);
when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52]));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<>() {{
add(sender1device);
}};
Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.empty());
String userAgent = "Signal-Desktop/1.2.3";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client);
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(2, futures.size());
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(0).complete(response);
futures.get(1).complete(response);
// We should delete all three messages even though we only sent two; one got discarded because it was too big for
// desktop clients.
verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), anyLong(), anyBoolean());
connection.stop();
verify(client).close(anyInt(), anyString());
}
@Test
public void testSendOversizedMessagesForNonDesktop() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
UUID accountUuid = UUID.randomUUID();
UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage(1L, false, "sender1", senderOneUuid, 1111, false, "first"));
add(createMessage(2L, false, "sender1", senderOneUuid, 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)));
add(createMessage(3L, false, "sender2", senderTwoUuid, 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
when(device.getId()).thenReturn(2L);
when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52]));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<>() {{
add(sender1device);
}};
Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.empty());
String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList);
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client);
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(3, futures.size());
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(0).complete(response);
futures.get(1).complete(response);
futures.get(2).complete(response);
verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), anyLong(), anyBoolean());
connection.stop();
verify(client).close(anyInt(), anyString());
}
private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);