diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index fcf7fe1ac..5195ff688 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -8,7 +8,11 @@ package org.whispersystems.textsecuregcm.storage; import static com.codahale.metrics.MetricRegistry.name; import static io.micrometer.core.instrument.Metrics.timer; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.InvalidProtocolBufferException; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.nio.ByteBuffer; import java.time.Duration; @@ -19,6 +23,8 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; import javax.annotation.Nonnull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -35,8 +41,12 @@ import software.amazon.awssdk.services.dynamodb.model.WriteRequest; public class MessagesDynamoDb extends AbstractDynamoDbStore { - private static final String KEY_PARTITION = "H"; - private static final String KEY_SORT = "S"; + @VisibleForTesting + static final String KEY_PARTITION = "H"; + + @VisibleForTesting + static final String KEY_SORT = "S"; + private static final String LOCAL_INDEX_MESSAGE_UUID_NAME = "Message_UUID_Index"; private static final String LOCAL_INDEX_MESSAGE_UUID_KEY_SORT = "U"; @@ -50,6 +60,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { private static final String KEY_CONTENT = "C"; private static final String KEY_TTL = "E"; + @VisibleForTesting + static final String KEY_ENVELOPE_BYTES = "EB"; + private final Timer storeTimer = timer(name(getClass(), "store")); private final Timer loadTimer = timer(name(getClass(), "load")); private final Timer deleteByGuid = timer(name(getClass(), "delete", "guid")); @@ -60,6 +73,11 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { private final String tableName; private final Duration timeToLive; + private static final Counter GET_MESSAGE_WITH_ATTRIBUTES_COUNTER = Metrics.counter(name(MessagesDynamoDb.class, "loadMessage"), "format", "attributes"); + private static final Counter GET_MESSAGE_WITH_ENVELOPE_COUNTER = Metrics.counter(name(MessagesDynamoDb.class, "loadMessage"), "format", "envelope"); + + private static final Logger logger = LoggerFactory.getLogger(MessagesDynamoDb.class); + public MessagesDynamoDb(DynamoDbClient dynamoDb, String tableName, Duration timeToLive) { super(dynamoDb); @@ -130,7 +148,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .build(); List messageEntities = new ArrayList<>(numberOfMessagesToFetch); for (Map message : db().queryPaginator(queryRequest).items()) { - messageEntities.add(convertItemToEnvelope(message)); + try { + messageEntities.add(convertItemToEnvelope(message)); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + } + if (messageEntities.size() == numberOfMessagesToFetch) { // queryPaginator() uses limit() as the page size, not as an absolute limit // …but a page might be smaller than limit, because a page is capped at 1 MB @@ -173,7 +196,12 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .returnValues(ReturnValue.ALL_OLD); final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); + try { + return Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + return Optional.empty(); + } } return Optional.empty(); @@ -193,7 +221,11 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } final DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest.build()); if (deleteItemResponse.attributes() != null && deleteItemResponse.attributes().containsKey(KEY_PARTITION)) { - result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); + try { + result = Optional.of(convertItemToEnvelope(deleteItemResponse.attributes())); + } catch (final InvalidProtocolBufferException e) { + logger.error("Failed to parse envelope", e); + } } } return result; @@ -233,20 +265,33 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { }); } - private MessageProtos.Envelope convertItemToEnvelope(final Map item) { - final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray()); - final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray()); - final int type = AttributeValues.getInt(item, KEY_TYPE, 0); - final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L); - final String source = AttributeValues.getString(item, KEY_SOURCE, null); - final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null); - final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0); - final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null); - final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null); - final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null); + private MessageProtos.Envelope convertItemToEnvelope(final Map item) + throws InvalidProtocolBufferException { + final MessageProtos.Envelope envelope; - return new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, - updatedPni, content, sortKey.getServerTimestamp()).toEnvelope(); + if (item.containsKey(KEY_ENVELOPE_BYTES)) { + envelope = MessageProtos.Envelope.parseFrom(item.get(KEY_ENVELOPE_BYTES).b().asByteArray()); + + GET_MESSAGE_WITH_ENVELOPE_COUNTER.increment(); + } else { + final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray()); + final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray()); + final int type = AttributeValues.getInt(item, KEY_TYPE, 0); + final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L); + final String source = AttributeValues.getString(item, KEY_SOURCE, null); + final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null); + final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0); + final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null); + final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null); + final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null); + + envelope = new OutgoingMessageEntity(messageUuid, type, timestamp, source, sourceUuid, sourceDevice, destinationUuid, + updatedPni, content, sortKey.getServerTimestamp()).toEnvelope(); + + GET_MESSAGE_WITH_ATTRIBUTES_COUNTER.increment(); + } + + return envelope; } private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) { @@ -268,11 +313,13 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { return message.getServerTimestamp() / 1000 + timeToLive.getSeconds(); } - private static AttributeValue convertPartitionKey(final UUID destinationAccountUuid) { + @VisibleForTesting + static AttributeValue convertPartitionKey(final UUID destinationAccountUuid) { return AttributeValues.fromUUID(destinationAccountUuid); } - private static AttributeValue convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { + @VisibleForTesting + static AttributeValue convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); byteBuffer.putLong(destinationDeviceId); byteBuffer.putLong(serverTimestamp); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java similarity index 85% rename from service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java index 596a5dec2..b620dc080 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java @@ -1,26 +1,26 @@ /* - * Copyright 2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ -package org.whispersystems.textsecuregcm.tests.storage; +package org.whispersystems.textsecuregcm.storage; import static org.assertj.core.api.Assertions.assertThat; import com.google.protobuf.ByteString; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.UUID; -import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; -import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; class MessagesDynamoDbTest { @@ -94,6 +94,31 @@ class MessagesDynamoDbTest { assertThat(messagesStored).element(2).isEqualTo(MESSAGE2); } + @Test + void testFetchBareEnvelope() { + final UUID destinationUuid = UUID.randomUUID(); + final long destinationDeviceId = Device.MASTER_ID; + final long serverTimestamp = System.currentTimeMillis(); + final UUID messageGuid = UUID.randomUUID(); + + final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() + .setServerGuid(messageGuid.toString()) + .setDestinationUuid(destinationUuid.toString()) + .setServerTimestamp(serverTimestamp) + .build(); + + dynamoDbExtension.getDynamoDbClient().putItem(PutItemRequest.builder() + .tableName(dynamoDbExtension.getTableName()) + .item(Map.of( + MessagesDynamoDb.KEY_PARTITION, MessagesDynamoDb.convertPartitionKey(destinationUuid), + MessagesDynamoDb.KEY_SORT, MessagesDynamoDb.convertSortKey(destinationDeviceId, serverTimestamp, messageGuid), + MessagesDynamoDb.KEY_ENVELOPE_BYTES, AttributeValue.builder().b(SdkBytes.fromByteArray(envelope.toByteArray())).build())) + .build()); + + assertThat(messagesDynamoDb.load(destinationUuid, destinationDeviceId, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)) + .isEqualTo(List.of(envelope)); + } + @Test void testDeleteForDestination() { final UUID destinationUuid = UUID.randomUUID();