Set source UUID when delivering envelopes from message cache/db on websocket

This commit is contained in:
Ehren Kret 2020-10-27 20:42:21 -05:00 committed by Jon Chambers
parent fb2baad7cc
commit 26870d134f
2 changed files with 74 additions and 7 deletions

View File

@ -181,6 +181,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
if (!Util.isEmpty(message.getSource())) {
builder.setSource(message.getSource())
.setSourceDevice(message.getSourceDevice());
if (message.getSourceUuid() != null) {
builder.setSourceUuid(message.getSourceUuid().toString());
}
}
if (message.getMessage() != null) {

View File

@ -1,8 +1,10 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.invocation.InvocationOnMock;
@ -42,6 +44,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
@ -61,13 +64,24 @@ public class WebSocketConnectionTest {
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private AccountAuthenticator accountAuthenticator;
private AccountsManager accountsManager;
private Account account;
private Device device;
private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender;
private ApnFallbackManager apnFallbackManager;
@Before
public void setup() {
accountAuthenticator = mock(AccountAuthenticator.class);
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class);
apnFallbackManager = mock(ApnFallbackManager.class);
}
@Test
public void testCredentials() throws Exception {
@ -437,6 +451,56 @@ public class WebSocketConnectionTest {
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
@Test(timeout = 5000L)
public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage);
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final CountDownLatch sendLatch = new CountDownLatch(messages.size());
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer(invocation -> {
sendLatch.countDown();
return CompletableFuture.completedFuture(successResponse);
});
connection.processStoredMessages();
sendLatch.await();
verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), argThat(argument -> {
if (argument.isEmpty()) {
return false;
}
final byte[] body = argument.get();
try {
final Envelope envelope = Envelope.parseFrom(body);
if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) {
return false;
}
return envelope.getSourceUuid().equals(senderUuid.toString());
} catch (InvalidProtocolBufferException e) {
return false;
}
}));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
@Test
public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class);