From 3636626e09d6a251a1b447dd67bf5da9fae0761f Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 27 Jul 2022 15:43:39 -0400 Subject: [PATCH] Make `Envelope` the main unit of currency when working with stored messages --- .../controllers/MessageController.java | 47 ++-- .../textsecuregcm/storage/MessagesCache.java | 32 +-- .../storage/MessagesDynamoDb.java | 43 ++-- .../storage/MessagesManager.java | 13 +- .../websocket/WebSocketConnection.java | 20 +- .../controllers/MessageControllerTest.java | 69 ++++-- .../storage/MessagesCacheTest.java | 19 +- .../tests/storage/MessagesDynamoDbTest.java | 76 ++----- .../websocket/WebSocketConnectionTest.java | 204 ++++++++---------- 9 files changed, 245 insertions(+), 278 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index e40685ceb..195d5d5f4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -21,7 +21,6 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -414,11 +413,19 @@ public class MessageController { RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice())); } - final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice( - auth.getAccount().getUuid(), - auth.getAuthenticatedDevice().getId(), - userAgent, - false); + final OutgoingMessageEntityList outgoingMessages; + { + final Pair, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice( + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + userAgent, + false); + + outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream() + .map(OutgoingMessageEntity::fromEnvelope) + .collect(Collectors.toList()), + messagesAndHasMore.second()); + } { String platform; @@ -450,24 +457,22 @@ public class MessageController { @DELETE @Path("/uuid/{uuid}") public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) { - try { - Optional message = messagesManager.delete( - auth.getAccount().getUuid(), - auth.getAuthenticatedDevice().getId(), - uuid, - null); + messagesManager.delete( + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + uuid, + null).ifPresent(deletedMessage -> { - if (message.isPresent()) { - WebSocketConnection.recordMessageDeliveryDuration(message.get().timestamp(), auth.getAuthenticatedDevice()); - if (!Util.isEmpty(message.get().source()) - && message.get().type() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { - receiptSender.sendReceipt(auth, message.get().sourceUuid(), message.get().timestamp()); + WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), auth.getAuthenticatedDevice()); + + if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { + try { + receiptSender.sendReceipt(auth, UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); + } catch (Exception e) { + logger.warn("Failed to send delivery receipt", e); } } - - } catch (NoSuchUserException e) { - logger.warn("Sending delivery receipt", e); - } + }); } @Timed diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 7f4cd7aa0..8ce563544 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -40,7 +40,6 @@ import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; @@ -148,13 +147,13 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp guid.toString().getBytes(StandardCharsets.UTF_8)))); } - public Optional remove(final UUID destinationUuid, final long destinationDevice, + public Optional remove(final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) { return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst(); } @SuppressWarnings("unchecked") - public List remove(final UUID destinationUuid, final long destinationDevice, + public List remove(final UUID destinationUuid, final long destinationDevice, final List messageGuids) { final List serialized = (List) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> @@ -164,11 +163,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)) .collect(Collectors.toList()))); - final List removedMessages = new ArrayList<>(serialized.size()); + final List removedMessages = new ArrayList<>(serialized.size()); for (final byte[] bytes : serialized) { try { - removedMessages.add(constructEntityFromEnvelope(MessageProtos.Envelope.parseFrom(bytes))); + removedMessages.add(MessageProtos.Envelope.parseFrom(bytes)); } catch (final InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); } @@ -183,7 +182,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @SuppressWarnings("unchecked") - public List get(final UUID destinationUuid, final long destinationDevice, final int limit) { + public List get(final UUID destinationUuid, final long destinationDevice, final int limit) { return getMessagesTimer.record(() -> { final List queueItems = (List) getItemsScript.executeBinary( List.of(getMessageQueueKey(destinationUuid, destinationDevice), @@ -193,7 +192,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp final long earliestAllowableEphemeralTimestamp = System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); - final List messageEntities; + final List messageEntities; final List staleEphemeralMessageGuids = new ArrayList<>(); if (queueItems.size() % 2 == 0) { @@ -207,9 +206,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp continue; } - final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8)); - - messageEntities.add(constructEntityFromEnvelope(message)); + messageEntities.add(message); } catch (InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); } @@ -379,21 +376,6 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } } - @VisibleForTesting - static OutgoingMessageEntity constructEntityFromEnvelope(MessageProtos.Envelope envelope) { - return new OutgoingMessageEntity( - envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null, - envelope.getType().getNumber(), - envelope.getTimestamp(), - envelope.getSource(), - envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, - envelope.getSourceDevice(), - envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null, - envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, - envelope.hasContent() ? envelope.getContent().toByteArray() : null, - envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0); - } - @VisibleForTesting static String getQueueName(final UUID accountUuid, final long deviceId) { return accountUuid + "::" + deviceId; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index 1f214c401..fcf7fe1ac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -112,7 +112,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems)); } - public List load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) { + public List load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) { return loadTimer.record(() -> { final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE); final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); @@ -128,9 +128,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) .limit(numberOfMessagesToFetch) .build(); - List messageEntities = new ArrayList<>(numberOfMessagesToFetch); + List messageEntities = new ArrayList<>(numberOfMessagesToFetch); for (Map message : db().queryPaginator(queryRequest).items()) { - messageEntities.add(convertItemToOutgoingMessageEntity(message)); + messageEntities.add(convertItemToEnvelope(message)); if (messageEntities.size() == numberOfMessagesToFetch) { // queryPaginator() uses limit() as the page size, not as an absolute limit // …but a page might be smaller than limit, because a page is capped at 1 MB @@ -141,7 +141,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { }); } - public Optional deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid, + public Optional deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid, final UUID messageUuid) { return deleteByGuid.record(() -> { final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); @@ -162,7 +162,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { }); } - public Optional deleteMessage(final UUID destinationAccountUuid, + public Optional deleteMessage(final UUID destinationAccountUuid, final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { return deleteByKey.record(() -> { final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); @@ -173,7 +173,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .returnValues(ReturnValue.ALL_OLD); final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - return Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes())); + return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); } return Optional.empty(); @@ -181,8 +181,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } @Nonnull - private Optional deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) { - Optional result = Optional.empty(); + private Optional deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) { + Optional result = Optional.empty(); for (Map item : db().queryPaginator(queryRequest).items()) { final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray(); DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() @@ -193,7 +193,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - result = Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes())); + result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); } } return result; @@ -233,19 +233,20 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { }); } - private OutgoingMessageEntity convertItemToOutgoingMessageEntity(Map message) { - final SortKey sortKey = convertSortKey(message.get(KEY_SORT).b().asByteArray()); - final UUID messageUuid = convertLocalIndexMessageUuidSortKey(message.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray()); - final int type = AttributeValues.getInt(message, KEY_TYPE, 0); - final long timestamp = AttributeValues.getLong(message, KEY_TIMESTAMP, 0L); - final String source = AttributeValues.getString(message, KEY_SOURCE, null); - final UUID sourceUuid = AttributeValues.getUUID(message, KEY_SOURCE_UUID, null); - final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0); - final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null); - final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null); - final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null); + private MessageProtos.Envelope convertItemToEnvelope(final Map item) { + final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray()); + final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray()); + final int type = AttributeValues.getInt(item, KEY_TYPE, 0); + final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L); + final String source = AttributeValues.getString(item, KEY_SOURCE, null); + final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null); + final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0); + final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null); + final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null); + final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null); + return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, - updatedPni, content, sortKey.getServerTimestamp()); + updatedPni, content, sortKey.getServerTimestamp()).toEnvelope(); } private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index a955f63da..587c519b6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -15,11 +15,10 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Pair; public class MessagesManager { @@ -61,10 +60,10 @@ public class MessagesManager { return messagesCache.hasMessages(destinationUuid, destinationDevice); } - public OutgoingMessageEntityList getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { + public Pair, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); - List messageList = new ArrayList<>(); + List messageList = new ArrayList<>(); if (!cachedMessagesOnly) { messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE)); @@ -74,7 +73,7 @@ public class MessagesManager { messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size())); } - return new OutgoingMessageEntityList(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE); + return new Pair<>(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE); } public void clear(UUID destinationUuid) { @@ -87,8 +86,8 @@ public class MessagesManager { messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId); } - public Optional delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) { - Optional removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid); + public Optional delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) { + Optional removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid); if (removed.isEmpty()) { if (serverTimestamp == null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index a6b0e06b7..2fa299a44 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; @@ -305,22 +306,25 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { try { - final OutgoingMessageEntityList messages = messagesManager + final Pair, Boolean> messagesAndHasMore = messagesManager .getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); - final CompletableFuture[] sendFutures = new CompletableFuture[messages.messages().size()]; + final List messages = messagesAndHasMore.first(); + final boolean hasMore = messagesAndHasMore.second(); - for (int i = 0; i < messages.messages().size(); i++) { - final OutgoingMessageEntity message = messages.messages().get(i); - final Envelope envelope = message.toEnvelope(); + final CompletableFuture[] sendFutures = new CompletableFuture[messages.size()]; + + for (int i = 0; i < messages.size(); i++) { + final Envelope envelope = messages.get(i); + final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { - messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.guid(), message.serverTimestamp()); + messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); discardedMessagesMeter.mark(); sendFutures[i] = CompletableFuture.completedFuture(null); } else { - sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(message.guid(), message.serverTimestamp()))); + sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(messageGuid, envelope.getServerTimestamp()))); } } @@ -329,7 +333,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac .orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS) .whenComplete((v, cause) -> { if (cause == null) { - if (messages.more()) { + if (hasMore) { sendNextMessagePage(cachedMessagesOnly, queueClearedFuture); } else { queueClearedFuture.complete(null); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 43b550406..8a8fb2dcf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -27,6 +27,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import com.google.common.collect.ImmutableSet; +import com.google.protobuf.ByteString; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; @@ -43,6 +44,7 @@ import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -57,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; +import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; @@ -77,6 +80,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -371,18 +375,22 @@ class MessageControllerTest { final long timestampTwo = 313388; final UUID messageGuidOne = UUID.randomUUID(); + final UUID messageGuidTwo = UUID.randomUUID(); final UUID sourceUuid = UUID.randomUUID(); final UUID updatedPniOne = UUID.randomUUID(); - List messages = new LinkedList<>() {{ - add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0)); - add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0)); - }}; + List messages = List.of( + generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0), + generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0) + ); - OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); + OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream() + .map(OutgoingMessageEntity::fromEnvelope) + .toList(), false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())) + .thenReturn(new Pair<>(messages, false)); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -397,7 +405,7 @@ class MessageControllerTest { assertEquals(response.messages().get(1).timestamp(), timestampTwo); assertEquals(response.messages().get(0).guid(), messageGuidOne); - assertNull(response.messages().get(1).guid()); + assertEquals(response.messages().get(1).guid(), messageGuidTwo); assertEquals(response.messages().get(0).sourceUuid(), sourceUuid); assertEquals(response.messages().get(1).sourceUuid(), sourceUuid); @@ -411,14 +419,13 @@ class MessageControllerTest { final long timestampOne = 313377; final long timestampTwo = 313388; - List messages = new LinkedList<>() {{ - add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0)); - add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)); - }}; + final List messages = List.of( + generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0), + generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0) + ); - OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())) + .thenReturn(new Pair<>(messages, false)); Response response = resources.getJerseyTest().target("/v1/messages/") @@ -437,12 +444,12 @@ class MessageControllerTest { UUID sourceUuid = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(new OutgoingMessageEntity( + when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(generateEnvelope( uuid1, Envelope.Type.CIPHERTEXT_VALUE, timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))); UUID uuid2 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(new OutgoingMessageEntity( + when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(generateEnvelope( uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))); @@ -624,4 +631,34 @@ class MessageControllerTest { Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false) ); } + + private static Envelope generateEnvelope(UUID guid, int type, long timestamp, String source, UUID sourceUuid, + int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { + + final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.forNumber(type)) + .setTimestamp(timestamp) + .setServerTimestamp(serverTimestamp) + .setDestinationUuid(destinationUuid.toString()) + .setServerGuid(guid.toString()); + + if (StringUtils.isNotEmpty(source)) { + builder.setSource(source) + .setSourceDevice(sourceDevice); + + if (sourceUuid != null) { + builder.setSourceUuid(sourceUuid.toString()); + } + } + + if (content != null) { + builder.setContent(ByteString.copyFrom(content)); + } + + if (updatedPni != null) { + builder.setUpdatedPni(updatedPni.toString()); + } + + return builder.build(); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index a7abfbb86..ea94c6772 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -34,7 +34,6 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; class MessagesCacheTest { @@ -103,11 +102,10 @@ class MessagesCacheTest { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, + final Optional maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid); - assertTrue(maybeRemovedMessage.isPresent()); - assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get()); + assertEquals(Optional.of(message), maybeRemovedMessage); } @ParameterizedTest @@ -135,14 +133,11 @@ class MessagesCacheTest { messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); } - final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, + final List removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) .collect(Collectors.toList())); - assertEquals(messagesToRemove.stream().map(MessagesCache::constructEntityFromEnvelope) - .collect(Collectors.toList()), - removedMessages); - + assertEquals(messagesToRemove, removedMessages); assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } @@ -163,14 +158,14 @@ class MessagesCacheTest { void testGetMessages(final boolean sealedSender) { final int messageCount = 100; - final List expectedMessages = new ArrayList<>(messageCount); + final List expectedMessages = new ArrayList<>(messageCount); for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); - final long messageId = messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); - expectedMessages.add(MessagesCache.constructEntityFromEnvelope(message)); + expectedMessages.add(message); } assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java index 7e162c757..596a5dec2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java @@ -83,15 +83,15 @@ class MessagesDynamoDbTest { final int destinationDeviceId = random.nextInt(255) + 1; messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); - final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, + final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); assertThat(messagesStored).isNotNull().hasSize(3); final MessageProtos.Envelope firstMessage = MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3; final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1; - assertThat(messagesStored).element(0).satisfies(verify(firstMessage)); - assertThat(messagesStored).element(1).satisfies(verify(secondMessage)); - assertThat(messagesStored).element(2).satisfies(verify(MESSAGE2)); + assertThat(messagesStored).element(0).isEqualTo(firstMessage); + assertThat(messagesStored).element(1).isEqualTo(secondMessage); + assertThat(messagesStored).element(2).isEqualTo(MESSAGE2); } @Test @@ -103,18 +103,18 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); } @Test @@ -126,19 +126,19 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); } @Test @@ -150,19 +150,19 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid, UUID.fromString(MESSAGE2.getServerGuid())); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } @@ -176,50 +176,20 @@ class MessagesDynamoDbTest { messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() - .hasSize(1).element(0).satisfies(verify(MESSAGE2)); + .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteMessage(secondDestinationUuid, 1, UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE1)); + .element(0).isEqualTo(MESSAGE1); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) - .element(0).satisfies(verify(MESSAGE3)); + .element(0).isEqualTo(MESSAGE3); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } - - private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) { - assertThat(retrieved.timestamp()).isEqualTo(inserted.getTimestamp()); - assertThat(retrieved.source()).isEqualTo(inserted.hasSource() ? inserted.getSource() : null); - assertThat(retrieved.sourceUuid()).isEqualTo(inserted.hasSourceUuid() ? UUID.fromString(inserted.getSourceUuid()) : null); - assertThat(retrieved.sourceDevice()).isEqualTo(inserted.getSourceDevice()); - assertThat(retrieved.type()).isEqualTo(inserted.getType().getNumber()); - assertThat(retrieved.content()).isEqualTo(inserted.hasContent() ? inserted.getContent().toByteArray() : null); - assertThat(retrieved.serverTimestamp()).isEqualTo(inserted.getServerTimestamp()); - assertThat(retrieved.guid()).isEqualTo(UUID.fromString(inserted.getServerGuid())); - assertThat(retrieved.destinationUuid()).isEqualTo(UUID.fromString(inserted.getDestinationUuid())); - } - - private static VerifyMessage verify(MessageProtos.Envelope expected) { - return new VerifyMessage(expected); - } - - private static final class VerifyMessage implements Consumer { - - private final MessageProtos.Envelope expected; - - public VerifyMessage(MessageProtos.Envelope expected) { - this.expected = expected; - } - - @Override - public void accept(OutgoingMessageEntity outgoingMessageEntity) { - verify(outgoingMessageEntity, expected); - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 9e4a36002..68c710eaa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -23,15 +23,17 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.auth.basic.BasicCredentials; import io.lettuce.core.RedisException; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -49,7 +51,6 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; @@ -111,14 +112,10 @@ class WebSocketConnectionTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.empty()); - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{ - put("login", new LinkedList<>() {{ - add(VALID_USER); - }}); - put("password", new LinkedList<>() {{ - add(VALID_PASSWORD); - }}); - }}); + + when(upgradeRequest.getParameterMap()).thenReturn(Map.of( + "login", List.of(VALID_USER), + "password", List.of(VALID_PASSWORD))); AuthenticationResult account = webSocketAuthenticator.authenticate(upgradeRequest); when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); @@ -127,14 +124,10 @@ class WebSocketConnectionTest { verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap>() {{ - put("login", new LinkedList() {{ - add(INVALID_USER); - }}); - put("password", new LinkedList() {{ - add(INVALID_PASSWORD); - }}); - }}); + when(upgradeRequest.getParameterMap()).thenReturn(Map.of( + "login", List.of(INVALID_USER), + "password", List.of(INVALID_PASSWORD) + )); account = webSocketAuthenticator.authenticate(upgradeRequest); assertFalse(account.getUser().isPresent()); @@ -149,13 +142,9 @@ class WebSocketConnectionTest { UUID senderOneUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID(); - List outgoingMessages = new LinkedList () {{ - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, "second")); - add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third")); - }}; - - OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); + List outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"), + createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, "second"), + createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third")); when(device.getId()).thenReturn(2L); @@ -175,7 +164,7 @@ class WebSocketConnectionTest { String userAgent = "user-agent"; when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) - .thenReturn(outgoingMessagesList); + .thenReturn(new Pair<>(outgoingMessages, false)); final List> futures = new LinkedList<>(); final WebSocketClient client = mock(WebSocketClient.class); @@ -207,7 +196,7 @@ class WebSocketConnectionTest { futures.get(0).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException()); - verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).guid()), eq(outgoingMessages.get(1).serverTimestamp())); + verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); verify(receiptSender, times(1)).sendReceipt(eq(auth), eq(senderOneUuid), eq(2222L)); connection.stop(); @@ -229,9 +218,9 @@ class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) - .thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first")), false)) - .thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")), false)); + .thenReturn(new Pair<>(Collections.emptyList(), false)) + .thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false)) + .thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -282,36 +271,27 @@ class WebSocketConnectionTest { final UUID senderTwoUuid = UUID.randomUUID(); final Envelope firstMessage = Envelope.newBuilder() - .setSource("sender1") - .setSourceUuid(UUID.randomUUID().toString()) - .setDestinationUuid(UUID.randomUUID().toString()) - .setUpdatedPni(UUID.randomUUID().toString()) - .setTimestamp(System.currentTimeMillis()) - .setSourceDevice(1) - .setType(Envelope.Type.CIPHERTEXT) - .build(); + .setServerGuid(UUID.randomUUID().toString()) + .setSource("sender1") + .setSourceUuid(UUID.randomUUID().toString()) + .setDestinationUuid(UUID.randomUUID().toString()) + .setUpdatedPni(UUID.randomUUID().toString()) + .setTimestamp(System.currentTimeMillis()) + .setSourceDevice(1) + .setType(Envelope.Type.CIPHERTEXT) + .build(); final Envelope secondMessage = Envelope.newBuilder() - .setSource("sender2") - .setSourceUuid(senderTwoUuid.toString()) - .setDestinationUuid(UUID.randomUUID().toString()) - .setTimestamp(System.currentTimeMillis()) - .setSourceDevice(2) - .setType(Envelope.Type.CIPHERTEXT) - .build(); + .setServerGuid(UUID.randomUUID().toString()) + .setSource("sender2") + .setSourceUuid(senderTwoUuid.toString()) + .setDestinationUuid(UUID.randomUUID().toString()) + .setTimestamp(System.currentTimeMillis()) + .setSourceDevice(2) + .setType(Envelope.Type.CIPHERTEXT) + .build(); - List pendingMessages = new LinkedList() {{ - add(new OutgoingMessageEntity(UUID.randomUUID(), firstMessage.getType().getNumber(), - firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()), - firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), UUID.fromString(firstMessage.getUpdatedPni()), - firstMessage.getContent().toByteArray(), 0)); - add(new OutgoingMessageEntity(UUID.randomUUID(), secondMessage.getType().getNumber(), - secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()), - secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), null, - secondMessage.getContent().toByteArray(), 0)); - }}; - - OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false); + final List pendingMessages = List.of(firstMessage, secondMessage); when(device.getId()).thenReturn(2L); @@ -331,20 +311,17 @@ class WebSocketConnectionTest { String userAgent = "user-agent"; when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) - .thenReturn(pendingMessagesList); + .thenReturn(new Pair<>(pendingMessages, false)); final List> futures = new LinkedList<>(); final WebSocketClient client = mock(WebSocketClient.class); when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any())) + .thenAnswer((Answer>) invocationOnMock -> { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, @@ -352,8 +329,7 @@ class WebSocketConnectionTest { connection.start(); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), - ArgumentMatchers.>any()); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()); assertEquals(futures.size(), 2); @@ -446,19 +422,16 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - final List firstPageMessages = - List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"), - createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")); + final List firstPageMessages = + List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), + createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); - final List secondPageMessages = - List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third")); - - final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); - final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + final List secondPageMessages = + List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)) - .thenReturn(firstPage) - .thenReturn(secondPage); + .thenReturn(new Pair<>(firstPageMessages, true)) + .thenReturn(new Pair<>(secondPageMessages, false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -493,11 +466,11 @@ class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); final UUID senderUuid = UUID.randomUUID(); - final List messages = List.of( - createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, false, "message the first")); - final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false); + final List messages = List.of( + createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, "message the first")); - when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)) + .thenReturn(new Pair<>(messages, false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -549,7 +522,7 @@ class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new Pair<>(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -577,20 +550,17 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - final List firstPageMessages = - List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"), - createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")); + final List firstPageMessages = + List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"), + createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second")); - final List secondPageMessages = - List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third")); - - final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); - final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + final List secondPageMessages = + List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(firstPage) - .thenReturn(secondPage) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new Pair<>(firstPageMessages, false)) + .thenReturn(new Pair<>(secondPageMessages, false)) + .thenReturn(new Pair<>(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -629,7 +599,7 @@ class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new Pair<>(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -662,7 +632,7 @@ class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new Pair<>(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -685,13 +655,11 @@ class WebSocketConnectionTest { UUID senderOneUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID(); - List outgoingMessages = new LinkedList () {{ - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); - add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third")); - }}; - - OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); + List outgoingMessages = List.of( + createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"), + createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, + RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)), + createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third")); when(device.getId()).thenReturn(2L); @@ -711,7 +679,7 @@ class WebSocketConnectionTest { String userAgent = "Signal-Desktop/1.2.3"; when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) - .thenReturn(outgoingMessagesList); + .thenReturn(new Pair<>(outgoingMessages, false)); final List> futures = new LinkedList<>(); final WebSocketClient client = mock(WebSocketClient.class); @@ -758,13 +726,10 @@ class WebSocketConnectionTest { UUID senderOneUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID(); - List outgoingMessages = new LinkedList () {{ - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); - add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); - add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third")); - }}; - - OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false); + List outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"), + createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, + RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)), + createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third")); when(device.getId()).thenReturn(2L); @@ -784,7 +749,7 @@ class WebSocketConnectionTest { String userAgent = "Signal-Android/4.68.3"; when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) - .thenReturn(outgoingMessagesList); + .thenReturn(new Pair<>(outgoingMessages, false)); final List> futures = new LinkedList<>(); final WebSocketClient client = mock(WebSocketClient.class); @@ -883,9 +848,18 @@ class WebSocketConnectionTest { verify(client, never()).close(anyInt(), anyString()); } - private OutgoingMessageEntity createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, boolean receipt, String content) { - return new OutgoingMessageEntity(UUID.randomUUID(), receipt ? Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, - timestamp, sender, senderUuid, 1, destinationUuid, null, content.getBytes(), 0); + private Envelope createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, String content) { + return Envelope.newBuilder() + .setServerGuid(UUID.randomUUID().toString()) + .setType(Envelope.Type.CIPHERTEXT) + .setTimestamp(timestamp) + .setServerTimestamp(0) + .setSource(sender) + .setSourceUuid(senderUuid.toString()) + .setSourceDevice(1) + .setDestinationUuid(destinationUuid.toString()) + .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8))) + .build(); } }