Make `Envelope` the main unit of currency when working with stored messages

This commit is contained in:
Jon Chambers 2022-07-27 15:43:39 -04:00 committed by Jon Chambers
parent 3e0919106d
commit 3636626e09
9 changed files with 245 additions and 278 deletions

View File

@ -21,7 +21,6 @@ import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -414,11 +413,19 @@ public class MessageController {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice())); RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice()));
} }
final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice( final OutgoingMessageEntityList outgoingMessages;
auth.getAccount().getUuid(), {
auth.getAuthenticatedDevice().getId(), final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice(
userAgent, auth.getAccount().getUuid(),
false); auth.getAuthenticatedDevice().getId(),
userAgent,
false);
outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream()
.map(OutgoingMessageEntity::fromEnvelope)
.collect(Collectors.toList()),
messagesAndHasMore.second());
}
{ {
String platform; String platform;
@ -450,24 +457,22 @@ public class MessageController {
@DELETE @DELETE
@Path("/uuid/{uuid}") @Path("/uuid/{uuid}")
public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) { public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
try { messagesManager.delete(
Optional<OutgoingMessageEntity> message = messagesManager.delete( auth.getAccount().getUuid(),
auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(),
auth.getAuthenticatedDevice().getId(), uuid,
uuid, null).ifPresent(deletedMessage -> {
null);
if (message.isPresent()) { WebSocketConnection.recordMessageDeliveryDuration(deletedMessage.getTimestamp(), auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(message.get().timestamp(), auth.getAuthenticatedDevice());
if (!Util.isEmpty(message.get().source()) if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
&& message.get().type() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { try {
receiptSender.sendReceipt(auth, message.get().sourceUuid(), message.get().timestamp()); receiptSender.sendReceipt(auth, UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
} }
} }
});
} catch (NoSuchUserException e) {
logger.warn("Sending delivery receipt", e);
}
} }
@Timed @Timed

View File

@ -40,7 +40,6 @@ import javax.annotation.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -148,13 +147,13 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
guid.toString().getBytes(StandardCharsets.UTF_8)))); guid.toString().getBytes(StandardCharsets.UTF_8))));
} }
public Optional<OutgoingMessageEntity> remove(final UUID destinationUuid, final long destinationDevice, public Optional<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
final UUID messageGuid) { final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst(); return remove(destinationUuid, destinationDevice, List.of(messageGuid)).stream().findFirst();
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> remove(final UUID destinationUuid, final long destinationDevice, public List<MessageProtos.Envelope> remove(final UUID destinationUuid, final long destinationDevice,
final List<UUID> messageGuids) { final List<UUID> messageGuids) {
final List<byte[]> serialized = (List<byte[]>) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, final List<byte[]> serialized = (List<byte[]>) Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG,
REMOVE_METHOD_UUID).record(() -> REMOVE_METHOD_UUID).record(() ->
@ -164,11 +163,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8)) messageGuids.stream().map(guid -> guid.toString().getBytes(StandardCharsets.UTF_8))
.collect(Collectors.toList()))); .collect(Collectors.toList())));
final List<OutgoingMessageEntity> removedMessages = new ArrayList<>(serialized.size()); final List<MessageProtos.Envelope> removedMessages = new ArrayList<>(serialized.size());
for (final byte[] bytes : serialized) { for (final byte[] bytes : serialized) {
try { try {
removedMessages.add(constructEntityFromEnvelope(MessageProtos.Envelope.parseFrom(bytes))); removedMessages.add(MessageProtos.Envelope.parseFrom(bytes));
} catch (final InvalidProtocolBufferException e) { } catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
@ -183,7 +182,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(final UUID destinationUuid, final long destinationDevice, final int limit) { public List<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDevice, final int limit) {
return getMessagesTimer.record(() -> { return getMessagesTimer.record(() -> {
final List<byte[]> queueItems = (List<byte[]>) getItemsScript.executeBinary( final List<byte[]> queueItems = (List<byte[]>) getItemsScript.executeBinary(
List.of(getMessageQueueKey(destinationUuid, destinationDevice), List.of(getMessageQueueKey(destinationUuid, destinationDevice),
@ -193,7 +192,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
final long earliestAllowableEphemeralTimestamp = final long earliestAllowableEphemeralTimestamp =
System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); System.currentTimeMillis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
final List<OutgoingMessageEntity> messageEntities; final List<MessageProtos.Envelope> messageEntities;
final List<UUID> staleEphemeralMessageGuids = new ArrayList<>(); final List<UUID> staleEphemeralMessageGuids = new ArrayList<>();
if (queueItems.size() % 2 == 0) { if (queueItems.size() % 2 == 0) {
@ -207,9 +206,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
continue; continue;
} }
final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8)); messageEntities.add(message);
messageEntities.add(constructEntityFromEnvelope(message));
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e); logger.warn("Failed to parse envelope", e);
} }
@ -379,21 +376,6 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
} }
} }
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
@VisibleForTesting @VisibleForTesting
static String getQueueName(final UUID accountUuid, final long deviceId) { static String getQueueName(final UUID accountUuid, final long deviceId) {
return accountUuid + "::" + deviceId; return accountUuid + "::" + deviceId;

View File

@ -112,7 +112,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems)); executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems));
} }
public List<OutgoingMessageEntity> load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) { public List<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final long destinationDeviceId, final int requestedNumberOfMessagesToFetch) {
return loadTimer.record(() -> { return loadTimer.record(() -> {
final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE); final int numberOfMessagesToFetch = Math.min(requestedNumberOfMessagesToFetch, RESULT_SET_CHUNK_SIZE);
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@ -128,9 +128,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId))) ":sortprefix", convertDestinationDeviceIdToSortKeyPrefix(destinationDeviceId)))
.limit(numberOfMessagesToFetch) .limit(numberOfMessagesToFetch)
.build(); .build();
List<OutgoingMessageEntity> messageEntities = new ArrayList<>(numberOfMessagesToFetch); List<MessageProtos.Envelope> messageEntities = new ArrayList<>(numberOfMessagesToFetch);
for (Map<String, AttributeValue> message : db().queryPaginator(queryRequest).items()) { for (Map<String, AttributeValue> message : db().queryPaginator(queryRequest).items()) {
messageEntities.add(convertItemToOutgoingMessageEntity(message)); messageEntities.add(convertItemToEnvelope(message));
if (messageEntities.size() == numberOfMessagesToFetch) { if (messageEntities.size() == numberOfMessagesToFetch) {
// queryPaginator() uses limit() as the page size, not as an absolute limit // queryPaginator() uses limit() as the page size, not as an absolute limit
// but a page might be smaller than limit, because a page is capped at 1 MB // but a page might be smaller than limit, because a page is capped at 1 MB
@ -141,7 +141,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}); });
} }
public Optional<OutgoingMessageEntity> deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid, public Optional<MessageProtos.Envelope> deleteMessageByDestinationAndGuid(final UUID destinationAccountUuid,
final UUID messageUuid) { final UUID messageUuid) {
return deleteByGuid.record(() -> { return deleteByGuid.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@ -162,7 +162,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}); });
} }
public Optional<OutgoingMessageEntity> deleteMessage(final UUID destinationAccountUuid, public Optional<MessageProtos.Envelope> deleteMessage(final UUID destinationAccountUuid,
final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
return deleteByKey.record(() -> { return deleteByKey.record(() -> {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@ -173,7 +173,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.returnValues(ReturnValue.ALL_OLD); .returnValues(ReturnValue.ALL_OLD);
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
return Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes())); return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
} }
return Optional.empty(); return Optional.empty();
@ -181,8 +181,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
} }
@Nonnull @Nonnull
private Optional<OutgoingMessageEntity> deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) { private Optional<MessageProtos.Envelope> deleteItemsMatchingQueryAndReturnFirstOneActuallyDeleted(AttributeValue partitionKey, QueryRequest queryRequest) {
Optional<OutgoingMessageEntity> result = Optional.empty(); Optional<MessageProtos.Envelope> result = Optional.empty();
for (Map<String, AttributeValue> item : db().queryPaginator(queryRequest).items()) { for (Map<String, AttributeValue> item : db().queryPaginator(queryRequest).items()) {
final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray(); final byte[] rangeKeyValue = item.get(KEY_SORT).b().asByteArray();
DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder() DeleteItemRequest.Builder deleteItemRequest = DeleteItemRequest.builder()
@ -193,7 +193,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
} }
final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build());
if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) {
result = Optional.of(convertItemToOutgoingMessageEntity(deleteItemResponse.attributes())); result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes()));
} }
} }
return result; return result;
@ -233,19 +233,20 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}); });
} }
private OutgoingMessageEntity convertItemToOutgoingMessageEntity(Map<String, AttributeValue> message) { private MessageProtos.Envelope convertItemToEnvelope(final Map<String, AttributeValue> item) {
final SortKey sortKey = convertSortKey(message.get(KEY_SORT).b().asByteArray()); final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray());
final UUID messageUuid = convertLocalIndexMessageUuidSortKey(message.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray()); final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray());
final int type = AttributeValues.getInt(message, KEY_TYPE, 0); final int type = AttributeValues.getInt(item, KEY_TYPE, 0);
final long timestamp = AttributeValues.getLong(message, KEY_TIMESTAMP, 0L); final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L);
final String source = AttributeValues.getString(message, KEY_SOURCE, null); final String source = AttributeValues.getString(item, KEY_SOURCE, null);
final UUID sourceUuid = AttributeValues.getUUID(message, KEY_SOURCE_UUID, null); final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null);
final int sourceDevice = AttributeValues.getInt(message, KEY_SOURCE_DEVICE, 0); final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(message, KEY_DESTINATION_UUID, null); final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(message, KEY_CONTENT, null); final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null);
final UUID updatedPni = AttributeValues.getUUID(message, KEY_UPDATED_PNI, null); final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null);
return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid,
updatedPni, content, sortKey.getServerTimestamp()); updatedPni, content, sortKey.getServerTimestamp()).toEnvelope();
} }
private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) { private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) {

View File

@ -15,11 +15,10 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
public class MessagesManager { public class MessagesManager {
@ -61,10 +60,10 @@ public class MessagesManager {
return messagesCache.hasMessages(destinationUuid, destinationDevice); return messagesCache.hasMessages(destinationUuid, destinationDevice);
} }
public OutgoingMessageEntityList getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));
List<OutgoingMessageEntity> messageList = new ArrayList<>(); List<Envelope> messageList = new ArrayList<>();
if (!cachedMessagesOnly) { if (!cachedMessagesOnly) {
messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE)); messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE));
@ -74,7 +73,7 @@ public class MessagesManager {
messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size())); messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size()));
} }
return new OutgoingMessageEntityList(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE); return new Pair<>(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE);
} }
public void clear(UUID destinationUuid) { public void clear(UUID destinationUuid) {
@ -87,8 +86,8 @@ public class MessagesManager {
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId);
} }
public Optional<OutgoingMessageEntity> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) { public Optional<Envelope> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, Long serverTimestamp) {
Optional<OutgoingMessageEntity> removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid); Optional<Envelope> removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid);
if (removed.isEmpty()) { if (removed.isEmpty()) {
if (serverTimestamp == null) { if (serverTimestamp == null) {

View File

@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil; import org.whispersystems.textsecuregcm.util.TimestampHeaderUtil;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@ -305,22 +306,25 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) { private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
try { try {
final OutgoingMessageEntityList messages = messagesManager final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager
.getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); .getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.messages().size()]; final List<Envelope> messages = messagesAndHasMore.first();
final boolean hasMore = messagesAndHasMore.second();
for (int i = 0; i < messages.messages().size(); i++) { final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.size()];
final OutgoingMessageEntity message = messages.messages().get(i);
final Envelope envelope = message.toEnvelope(); for (int i = 0; i < messages.size(); i++) {
final Envelope envelope = messages.get(i);
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.guid(), message.serverTimestamp()); messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp());
discardedMessagesMeter.mark(); discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null); sendFutures[i] = CompletableFuture.completedFuture(null);
} else { } else {
sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(message.guid(), message.serverTimestamp()))); sendFutures[i] = sendMessage(envelope, Optional.of(new StoredMessageInfo(messageGuid, envelope.getServerTimestamp())));
} }
} }
@ -329,7 +333,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
.orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS) .orTimeout(sendFuturesTimeoutMillis, TimeUnit.MILLISECONDS)
.whenComplete((v, cause) -> { .whenComplete((v, cause) -> {
if (cause == null) { if (cause == null) {
if (messages.more()) { if (hasMore) {
sendNextMessagePage(cachedMessagesOnly, queueClearedFuture); sendNextMessagePage(cachedMessagesOnly, queueClearedFuture);
} else { } else {
queueClearedFuture.complete(null); queueClearedFuture.complete(null);

View File

@ -27,6 +27,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
@ -43,6 +44,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -57,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
@ -77,6 +80,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -371,18 +375,22 @@ class MessageControllerTest {
final long timestampTwo = 313388; final long timestampTwo = 313388;
final UUID messageGuidOne = UUID.randomUUID(); final UUID messageGuidOne = UUID.randomUUID();
final UUID messageGuidTwo = UUID.randomUUID();
final UUID sourceUuid = UUID.randomUUID(); final UUID sourceUuid = UUID.randomUUID();
final UUID updatedPniOne = UUID.randomUUID(); final UUID updatedPniOne = UUID.randomUUID();
List<OutgoingMessageEntity> messages = new LinkedList<>() {{ List<Envelope> messages = List.of(
add(new OutgoingMessageEntity(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0)); generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0),
add(new OutgoingMessageEntity(null, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0)); generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", sourceUuid, 2, AuthHelper.VALID_UUID, null, null, 0)
}}; );
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream()
.map(OutgoingMessageEntity::fromEnvelope)
.toList(), false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
OutgoingMessageEntityList response = OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
@ -397,7 +405,7 @@ class MessageControllerTest {
assertEquals(response.messages().get(1).timestamp(), timestampTwo); assertEquals(response.messages().get(1).timestamp(), timestampTwo);
assertEquals(response.messages().get(0).guid(), messageGuidOne); assertEquals(response.messages().get(0).guid(), messageGuidOne);
assertNull(response.messages().get(1).guid()); assertEquals(response.messages().get(1).guid(), messageGuidTwo);
assertEquals(response.messages().get(0).sourceUuid(), sourceUuid); assertEquals(response.messages().get(0).sourceUuid(), sourceUuid);
assertEquals(response.messages().get(1).sourceUuid(), sourceUuid); assertEquals(response.messages().get(1).sourceUuid(), sourceUuid);
@ -411,14 +419,13 @@ class MessageControllerTest {
final long timestampOne = 313377; final long timestampOne = 313377;
final long timestampTwo = 313388; final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<>() {{ final List<Envelope> messages = List.of(
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0)); generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0),
add(new OutgoingMessageEntity(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)); generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, "+14152222222", UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)
}}; );
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
Response response = Response response =
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
@ -437,12 +444,12 @@ class MessageControllerTest {
UUID sourceUuid = UUID.randomUUID(); UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(new OutgoingMessageEntity( when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)).thenReturn(Optional.of(generateEnvelope(
uuid1, Envelope.Type.CIPHERTEXT_VALUE, uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))); timestamp, "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)));
UUID uuid2 = UUID.randomUUID(); UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(new OutgoingMessageEntity( when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)).thenReturn(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))); System.currentTimeMillis(), "+14152222222", sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)));
@ -624,4 +631,34 @@ class MessageControllerTest {
Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false) Arguments.of("fixtures/current_message_single_device_server_receipt_type.json", false)
); );
} }
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, String source, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString())
.setServerGuid(guid.toString());
if (StringUtils.isNotEmpty(source)) {
builder.setSource(source)
.setSourceDevice(sourceDevice);
if (sourceUuid != null) {
builder.setSourceUuid(sourceUuid.toString());
}
}
if (content != null) {
builder.setContent(ByteString.copyFrom(content));
}
if (updatedPni != null) {
builder.setUpdatedPni(updatedPni.toString());
}
return builder.build();
}
} }

View File

@ -34,7 +34,6 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheTest { class MessagesCacheTest {
@ -103,11 +102,10 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, final Optional<MessageProtos.Envelope> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid); DESTINATION_DEVICE_ID, messageGuid);
assertTrue(maybeRemovedMessage.isPresent()); assertEquals(Optional.of(message), maybeRemovedMessage);
assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get());
} }
@ParameterizedTest @ParameterizedTest
@ -135,14 +133,11 @@ class MessagesCacheTest {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
} }
final List<OutgoingMessageEntity> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, final List<MessageProtos.Envelope> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid())) messagesToRemove.stream().map(message -> UUID.fromString(message.getServerGuid()))
.collect(Collectors.toList())); .collect(Collectors.toList()));
assertEquals(messagesToRemove.stream().map(MessagesCache::constructEntityFromEnvelope) assertEquals(messagesToRemove, removedMessages);
.collect(Collectors.toList()),
removedMessages);
assertEquals(messagesToPreserve, assertEquals(messagesToPreserve,
messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
} }
@ -163,14 +158,14 @@ class MessagesCacheTest {
void testGetMessages(final boolean sealedSender) { void testGetMessages(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount); final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
expectedMessages.add(MessagesCache.constructEntityFromEnvelope(message)); expectedMessages.add(message);
} }
assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));

View File

@ -83,15 +83,15 @@ class MessagesDynamoDbTest {
final int destinationDeviceId = random.nextInt(255) + 1; final int destinationDeviceId = random.nextInt(255) + 1;
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<OutgoingMessageEntity> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, final List<MessageProtos.Envelope> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); MessagesDynamoDb.RESULT_SET_CHUNK_SIZE);
assertThat(messagesStored).isNotNull().hasSize(3); assertThat(messagesStored).isNotNull().hasSize(3);
final MessageProtos.Envelope firstMessage = final MessageProtos.Envelope firstMessage =
MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3; MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3;
final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1; final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1;
assertThat(messagesStored).element(0).satisfies(verify(firstMessage)); assertThat(messagesStored).element(0).isEqualTo(firstMessage);
assertThat(messagesStored).element(1).satisfies(verify(secondMessage)); assertThat(messagesStored).element(1).isEqualTo(secondMessage);
assertThat(messagesStored).element(2).satisfies(verify(MESSAGE2)); assertThat(messagesStored).element(2).isEqualTo(MESSAGE2);
} }
@Test @Test
@ -103,18 +103,18 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
} }
@Test @Test
@ -126,19 +126,19 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
} }
@Test @Test
@ -150,19 +150,19 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid, messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid,
UUID.fromString(MESSAGE2.getServerGuid())); UUID.fromString(MESSAGE2.getServerGuid()));
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty(); .isEmpty();
} }
@ -176,50 +176,20 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2)); .hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1, messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()); UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp());
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1)); .element(0).isEqualTo(MESSAGE1);
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3)); .element(0).isEqualTo(MESSAGE3);
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty(); .isEmpty();
} }
private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) {
assertThat(retrieved.timestamp()).isEqualTo(inserted.getTimestamp());
assertThat(retrieved.source()).isEqualTo(inserted.hasSource() ? inserted.getSource() : null);
assertThat(retrieved.sourceUuid()).isEqualTo(inserted.hasSourceUuid() ? UUID.fromString(inserted.getSourceUuid()) : null);
assertThat(retrieved.sourceDevice()).isEqualTo(inserted.getSourceDevice());
assertThat(retrieved.type()).isEqualTo(inserted.getType().getNumber());
assertThat(retrieved.content()).isEqualTo(inserted.hasContent() ? inserted.getContent().toByteArray() : null);
assertThat(retrieved.serverTimestamp()).isEqualTo(inserted.getServerTimestamp());
assertThat(retrieved.guid()).isEqualTo(UUID.fromString(inserted.getServerGuid()));
assertThat(retrieved.destinationUuid()).isEqualTo(UUID.fromString(inserted.getDestinationUuid()));
}
private static VerifyMessage verify(MessageProtos.Envelope expected) {
return new VerifyMessage(expected);
}
private static final class VerifyMessage implements Consumer<OutgoingMessageEntity> {
private final MessageProtos.Envelope expected;
public VerifyMessage(MessageProtos.Envelope expected) {
this.expected = expected;
}
@Override
public void accept(OutgoingMessageEntity outgoingMessageEntity) {
verify(outgoingMessageEntity, expected);
}
}
} }

View File

@ -23,15 +23,17 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -49,7 +51,6 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@ -111,14 +112,10 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty()); .thenReturn(Optional.empty());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{
put("login", new LinkedList<>() {{ when(upgradeRequest.getParameterMap()).thenReturn(Map.of(
add(VALID_USER); "login", List.of(VALID_USER),
}}); "password", List.of(VALID_PASSWORD)));
put("password", new LinkedList<>() {{
add(VALID_PASSWORD);
}});
}});
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest); AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
@ -127,14 +124,10 @@ class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{ when(upgradeRequest.getParameterMap()).thenReturn(Map.of(
put("login", new LinkedList<String>() {{ "login", List.of(INVALID_USER),
add(INVALID_USER); "password", List.of(INVALID_PASSWORD)
}}); ));
put("password", new LinkedList<String>() {{
add(INVALID_PASSWORD);
}});
}});
account = webSocketAuthenticator.authenticate(upgradeRequest); account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent()); assertFalse(account.getUser().isPresent());
@ -149,13 +142,9 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID(); UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{ List<Envelope> outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, "second"),
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, "second")); createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
@ -175,7 +164,7 @@ class WebSocketConnectionTest {
String userAgent = "user-agent"; String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList); .thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
@ -207,7 +196,7 @@ class WebSocketConnectionTest {
futures.get(0).completeExceptionally(new IOException()); futures.get(0).completeExceptionally(new IOException());
futures.get(2).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).guid()), eq(outgoingMessages.get(1).serverTimestamp())); verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq(senderOneUuid), eq(2222L)); verify(receiptSender, times(1)).sendReceipt(eq(auth), eq(senderOneUuid), eq(2222L));
connection.stop(); connection.stop();
@ -229,9 +218,9 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) .thenReturn(new Pair<>(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first")), false)) .thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")), false)); .thenReturn(new Pair<>(List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -282,36 +271,27 @@ class WebSocketConnectionTest {
final UUID senderTwoUuid = UUID.randomUUID(); final UUID senderTwoUuid = UUID.randomUUID();
final Envelope firstMessage = Envelope.newBuilder() final Envelope firstMessage = Envelope.newBuilder()
.setSource("sender1") .setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(UUID.randomUUID().toString()) .setSource("sender1")
.setDestinationUuid(UUID.randomUUID().toString()) .setSourceUuid(UUID.randomUUID().toString())
.setUpdatedPni(UUID.randomUUID().toString()) .setDestinationUuid(UUID.randomUUID().toString())
.setTimestamp(System.currentTimeMillis()) .setUpdatedPni(UUID.randomUUID().toString())
.setSourceDevice(1) .setTimestamp(System.currentTimeMillis())
.setType(Envelope.Type.CIPHERTEXT) .setSourceDevice(1)
.build(); .setType(Envelope.Type.CIPHERTEXT)
.build();
final Envelope secondMessage = Envelope.newBuilder() final Envelope secondMessage = Envelope.newBuilder()
.setSource("sender2") .setServerGuid(UUID.randomUUID().toString())
.setSourceUuid(senderTwoUuid.toString()) .setSource("sender2")
.setDestinationUuid(UUID.randomUUID().toString()) .setSourceUuid(senderTwoUuid.toString())
.setTimestamp(System.currentTimeMillis()) .setDestinationUuid(UUID.randomUUID().toString())
.setSourceDevice(2) .setTimestamp(System.currentTimeMillis())
.setType(Envelope.Type.CIPHERTEXT) .setSourceDevice(2)
.build(); .setType(Envelope.Type.CIPHERTEXT)
.build();
List<OutgoingMessageEntity> pendingMessages = new LinkedList<OutgoingMessageEntity>() {{ final List<Envelope> pendingMessages = List.of(firstMessage, secondMessage);
add(new OutgoingMessageEntity(UUID.randomUUID(), firstMessage.getType().getNumber(),
firstMessage.getTimestamp(), firstMessage.getSource(), UUID.fromString(firstMessage.getSourceUuid()),
firstMessage.getSourceDevice(), UUID.fromString(firstMessage.getDestinationUuid()), UUID.fromString(firstMessage.getUpdatedPni()),
firstMessage.getContent().toByteArray(), 0));
add(new OutgoingMessageEntity(UUID.randomUUID(), secondMessage.getType().getNumber(),
secondMessage.getTimestamp(), secondMessage.getSource(), UUID.fromString(secondMessage.getSourceUuid()),
secondMessage.getSourceDevice(), UUID.fromString(secondMessage.getDestinationUuid()), null,
secondMessage.getContent().toByteArray(), 0));
}};
OutgoingMessageEntityList pendingMessagesList = new OutgoingMessageEntityList(pendingMessages, false);
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
@ -331,20 +311,17 @@ class WebSocketConnectionTest {
String userAgent = "user-agent"; String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(pendingMessagesList); .thenReturn(new Pair<>(pendingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent); when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any())) when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() { .thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>) invocationOnMock -> {
@Override CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) { futures.add(future);
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>(); return future;
futures.add(future);
return future;
}
}); });
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
@ -352,8 +329,7 @@ class WebSocketConnectionTest {
connection.start(); connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any());
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(futures.size(), 2); assertEquals(futures.size(), 2);
@ -446,19 +422,16 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L); when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages = final List<Envelope> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"), List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")); createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second"));
final List<OutgoingMessageEntity> secondPageMessages = final List<Envelope> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third")); List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)) when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(firstPage) .thenReturn(new Pair<>(firstPageMessages, true))
.thenReturn(secondPage); .thenReturn(new Pair<>(secondPageMessages, false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -493,11 +466,11 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID(); final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of( final List<Envelope> messages = List.of(
createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, false, "message the first")); createMessage("senderE164", senderUuid, UUID.randomUUID(), 1111L, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
.thenReturn(new Pair<>(messages, false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -549,7 +522,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -577,20 +550,17 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L); when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages = final List<Envelope> firstPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, false, "first"), List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, false, "second")); createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 2222, "second"));
final List<OutgoingMessageEntity> secondPageMessages = final List<Envelope> secondPageMessages =
List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, false, "third")); List.of(createMessage("sender1", UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false);
final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false);
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(firstPage) .thenReturn(new Pair<>(firstPageMessages, false))
.thenReturn(secondPage) .thenReturn(new Pair<>(secondPageMessages, false))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -629,7 +599,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -662,7 +632,7 @@ class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -685,13 +655,11 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID(); UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{ List<Envelope> outgoingMessages = List.of(
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222,
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third")); RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
}}; createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
@ -711,7 +679,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Desktop/1.2.3"; String userAgent = "Signal-Desktop/1.2.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList); .thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
@ -758,13 +726,10 @@ class WebSocketConnectionTest {
UUID senderOneUuid = UUID.randomUUID(); UUID senderOneUuid = UUID.randomUUID();
UUID senderTwoUuid = UUID.randomUUID(); UUID senderTwoUuid = UUID.randomUUID();
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{ List<Envelope> outgoingMessages = List.of(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, "first"),
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 1111, false, "first")); createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222,
add(createMessage("sender1", senderOneUuid, UUID.randomUUID(), 2222, false, RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1))); RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
add(createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, false, "third")); createMessage("sender2", senderTwoUuid, UUID.randomUUID(), 3333, "third"));
}};
OutgoingMessageEntityList outgoingMessagesList = new OutgoingMessageEntityList(outgoingMessages, false);
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
@ -784,7 +749,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Android/4.68.3"; String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenReturn(outgoingMessagesList); .thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>(); final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
@ -883,9 +848,18 @@ class WebSocketConnectionTest {
verify(client, never()).close(anyInt(), anyString()); verify(client, never()).close(anyInt(), anyString());
} }
private OutgoingMessageEntity createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, boolean receipt, String content) { private Envelope createMessage(String sender, UUID senderUuid, UUID destinationUuid, long timestamp, String content) {
return new OutgoingMessageEntity(UUID.randomUUID(), receipt ? Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, return Envelope.newBuilder()
timestamp, sender, senderUuid, 1, destinationUuid, null, content.getBytes(), 0); .setServerGuid(UUID.randomUUID().toString())
.setType(Envelope.Type.CIPHERTEXT)
.setTimestamp(timestamp)
.setServerTimestamp(0)
.setSource(sender)
.setSourceUuid(senderUuid.toString())
.setSourceDevice(1)
.setDestinationUuid(destinationUuid.toString())
.setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
.build();
} }
} }