Include server GUID when sending messages over websocket

This commit is contained in:
Chris Eager 2021-05-14 17:01:22 -05:00 committed by Chris Eager
parent d59eabd9d7
commit 00c9023e74
2 changed files with 33 additions and 5 deletions

View File

@ -298,6 +298,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
builder.setRelay(message.getRelay());
}
builder.setServerGuid(message.getGuid().toString());
final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
@ -31,6 +32,7 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.After;
import org.junit.Before;
@ -39,6 +41,7 @@ import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
@ -122,7 +125,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope.toBuilder().clearServerGuid().build());
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
@ -133,7 +136,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope.toBuilder().clearServerGuid().build());
expectedMessages.add(envelope);
}
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@ -185,11 +188,15 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
persistedMessages.add(generateRandomMessage(UUID.randomUUID()));
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
@ -197,15 +204,34 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), generateRandomMessage(messageGuid));
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope);
}
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed")));
webSocketConnection.processStoredMessages();
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any());
//noinspection unchecked
ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = messageBodyCaptor.getAllValues().stream()
.map(Optional::get)
.map(messageBytes -> {
try {
return Envelope.parseFrom(messageBytes);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
assertTrue(expectedMessages.containsAll(sentMessages));
}
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {