From 26870d134f43fa8a4b215740b67ab8aad3a7893f Mon Sep 17 00:00:00 2001 From: Ehren Kret Date: Tue, 27 Oct 2020 20:42:21 -0500 Subject: [PATCH] Set source UUID when delivering envelopes from message cache/db on websocket --- .../websocket/WebSocketConnection.java | 3 + .../websocket/WebSocketConnectionTest.java | 78 +++++++++++++++++-- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 8dbcb7e2f..9592b7e17 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -181,6 +181,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac if (!Util.isEmpty(message.getSource())) { builder.setSource(message.getSource()) .setSourceDevice(message.getSourceDevice()); + if (message.getSourceUuid() != null) { + builder.setSourceUuid(message.getSourceUuid().toString()); + } } if (message.getMessage() != null) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 14c89ffa3..dc90af5e9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,8 +1,10 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.auth.basic.BasicCredentials; import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; @@ -42,6 +44,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyInt; @@ -61,13 +64,24 @@ public class WebSocketConnectionTest { private static final String VALID_PASSWORD = "secure"; private static final String INVALID_PASSWORD = "insecure"; - private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class); - private static final AccountsManager accountsManager = mock(AccountsManager.class); - private static final Account account = mock(Account.class ); - private static final Device device = mock(Device.class ); - private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); - private static final ReceiptSender receiptSender = mock(ReceiptSender.class); - private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); + private AccountAuthenticator accountAuthenticator; + private AccountsManager accountsManager; + private Account account; + private Device device; + private UpgradeRequest upgradeRequest; + private ReceiptSender receiptSender; + private ApnFallbackManager apnFallbackManager; + + @Before + public void setup() { + accountAuthenticator = mock(AccountAuthenticator.class); + accountsManager = mock(AccountsManager.class); + account = mock(Account.class); + device = mock(Device.class); + upgradeRequest = mock(UpgradeRequest.class); + receiptSender = mock(ReceiptSender.class); + apnFallbackManager = mock(ApnFallbackManager.class); + } @Test public void testCredentials() throws Exception { @@ -437,6 +451,56 @@ public class WebSocketConnectionTest { verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } + @Test(timeout = 5000L) + public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final UUID senderUuid = UUID.randomUUID(); + final List messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first")); + final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch sendLatch = new CountDownLatch(messages.size()); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer(invocation -> { + sendLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + connection.processStoredMessages(); + + sendLatch.await(); + + verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), argThat(argument -> { + if (argument.isEmpty()) { + return false; + } + + final byte[] body = argument.get(); + try { + final Envelope envelope = Envelope.parseFrom(body); + if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) { + return false; + } + return envelope.getSourceUuid().equals(senderUuid.toString()); + } catch (InvalidProtocolBufferException e) { + return false; + } + })); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + @Test public void testProcessStoredMessagesSingleEmptyCall() { final MessagesManager messagesManager = mock(MessagesManager.class);