diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b836b90c6..ea656351a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -298,6 +298,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac builder.setRelay(message.getRelay()); } + builder.setServerGuid(message.getGuid().toString()); + final Envelope envelope = builder.build(); if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 925efd585..0a11cf507 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.websocket; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; @@ -31,6 +32,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; import org.junit.After; import org.junit.Before; @@ -39,6 +41,7 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; @@ -122,7 +125,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); persistedMessages.add(envelope); - expectedMessages.add(envelope.toBuilder().clearServerGuid().build()); + expectedMessages.add(envelope); } messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); @@ -133,7 +136,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); - expectedMessages.add(envelope.toBuilder().clearServerGuid().build()); + expectedMessages.add(envelope); } final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -185,11 +188,15 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest final int persistedMessageCount = 207; final int cachedMessageCount = 173; + final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); + { final List persistedMessages = new ArrayList<>(persistedMessageCount); for (int i = 0; i < persistedMessageCount; i++) { - persistedMessages.add(generateRandomMessage(UUID.randomUUID())); + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + persistedMessages.add(envelope); + expectedMessages.add(envelope); } messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); @@ -197,15 +204,34 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), generateRandomMessage(messageGuid)); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + + expectedMessages.add(envelope); } when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed"))); webSocketConnection.processStoredMessages(); - verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()); + //noinspection unchecked + ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); + + verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + + final List sentMessages = messageBodyCaptor.getAllValues().stream() + .map(Optional::get) + .map(messageBytes -> { + try { + return Envelope.parseFrom(messageBytes); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + + assertTrue(expectedMessages.containsAll(sentMessages)); } private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {