diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 9bd38aaf8..b8b92c102 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -29,6 +29,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil; import org.whispersystems.textsecuregcm.util.Util; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.WebSocketClient; @@ -58,10 +59,14 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent")); private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures")); private static final Meter clientNonSuccessResponseMeter = metricRegistry.meter(name(WebSocketConnection.class, "clientNonSuccessResponse")); + private static final Meter discardedMessagesMeter = metricRegistry.meter(name(WebSocketConnection.class, "discardedMessages")); private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement"); private static final String DISPLACEMENT_PLATFORM_TAG_NAME = "platform"; + @VisibleForTesting + static final int MAX_DESKTOP_MESSAGE_SIZE = 1024 * 1024; + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private final ReceiptSender receiptSender; @@ -71,6 +76,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final Device device; private final WebSocketClient client; + private final boolean isDesktopClient; + private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); private final AtomicReference storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); @@ -92,6 +99,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac this.account = account; this.device = device; this.client = client; + + Optional maybePlatform; + + try { + maybePlatform = Optional.of(UserAgentUtil.parseUserAgentString(client.getUserAgent()).getPlatform()); + } catch (final UnrecognizedUserAgentException e) { + maybePlatform = Optional.empty(); + } + + this.isDesktopClient = maybePlatform.map(platform -> platform == ClientPlatform.DESKTOP).orElse(false); } public void start() { @@ -211,7 +228,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac builder.setRelay(message.getRelay()); } - sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); + final Envelope envelope = builder.build(); + + if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { + messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), message.getId(), message.isCached()); + discardedMessagesMeter.mark(); + + sendFutures[i] = CompletableFuture.completedFuture(null); + } else { + sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); + } } CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index cd8bf82e6..bb4959966 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.auth.basic.BasicCredentials; +import org.apache.commons.lang3.RandomStringUtils; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Before; import org.junit.Test; @@ -643,6 +644,153 @@ public class WebSocketConnectionTest { verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); } + @Test + public void testDiscardOversizedMessagesForDesktop() { + MessagesManager storedMessages = mock(MessagesManager.class); + + UUID accountUuid = UUID.randomUUID(); + UUID senderOneUuid = UUID.randomUUID(); + UUID senderTwoUuid = UUID.randomUUID(); + + List outgoingMessages = new LinkedList () {{ + add(createMessage(1L, false, "sender1", senderOneUuid, 1111, false, "first")); + add(createMessage(2L, false, "sender1", senderOneUuid, 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); + add(createMessage(3L, false, "sender2", senderTwoUuid, 3333, false, "third")); + }}; + + OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); + + when(device.getId()).thenReturn(2L); + when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52])); + + when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); + + final Device sender1device = mock(Device.class); + + Set sender1devices = new HashSet<>() {{ + add(sender1device); + }}; + + Account sender1 = mock(Account.class); + when(sender1.getDevices()).thenReturn(sender1devices); + + when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); + when(accountsManager.get("sender2")).thenReturn(Optional.empty()); + + String userAgent = "Signal-Desktop/1.2.3"; + + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + .thenReturn(outgoingMessagesList); + + final List> futures = new LinkedList<>(); + final WebSocketClient client = mock(WebSocketClient.class); + + when(client.getUserAgent()).thenReturn(userAgent); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) + .thenAnswer(new Answer>() { + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + } + }); + + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client); + + connection.start(); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + + assertEquals(2, futures.size()); + + WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); + when(response.getStatus()).thenReturn(200); + futures.get(0).complete(response); + futures.get(1).complete(response); + + // We should delete all three messages even though we only sent two; one got discarded because it was too big for + // desktop clients. + verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), anyLong(), anyBoolean()); + + connection.stop(); + verify(client).close(anyInt(), anyString()); + } + + @Test + public void testSendOversizedMessagesForNonDesktop() throws Exception { + MessagesManager storedMessages = mock(MessagesManager.class); + + UUID accountUuid = UUID.randomUUID(); + UUID senderOneUuid = UUID.randomUUID(); + UUID senderTwoUuid = UUID.randomUUID(); + + List outgoingMessages = new LinkedList () {{ + add(createMessage(1L, false, "sender1", senderOneUuid, 1111, false, "first")); + add(createMessage(2L, false, "sender1", senderOneUuid, 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); + add(createMessage(3L, false, "sender2", senderTwoUuid, 3333, false, "third")); + }}; + + OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); + + when(device.getId()).thenReturn(2L); + when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52])); + + when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); + + final Device sender1device = mock(Device.class); + + Set sender1devices = new HashSet<>() {{ + add(sender1device); + }}; + + Account sender1 = mock(Account.class); + when(sender1.getDevices()).thenReturn(sender1devices); + + when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); + when(accountsManager.get("sender2")).thenReturn(Optional.empty()); + + String userAgent = "Signal-Android/4.68.3"; + + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + .thenReturn(outgoingMessagesList); + + final List> futures = new LinkedList<>(); + final WebSocketClient client = mock(WebSocketClient.class); + + when(client.getUserAgent()).thenReturn(userAgent); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) + .thenAnswer(new Answer>() { + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + } + }); + + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client); + + connection.start(); + verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + + assertEquals(3, futures.size()); + + WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); + when(response.getStatus()).thenReturn(200); + futures.get(0).complete(response); + futures.get(1).complete(response); + futures.get(2).complete(response); + + verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), anyLong(), anyBoolean()); + + connection.stop(); + verify(client).close(anyInt(), anyString()); + } + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);