Migrate `MessagesDynamoDbRule` to `MessagesDynamoDbExtension`

This commit is contained in:
Chris Eager 2021-09-14 14:54:22 -07:00 committed by Jon Chambers
parent 6a5d475198
commit 83e0a19561
6 changed files with 370 additions and 322 deletions

View File

@ -25,6 +25,7 @@ import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType; import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput;
public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback { public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback {
@ -45,6 +46,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
private final List<AttributeDefinition> attributeDefinitions; private final List<AttributeDefinition> attributeDefinitions;
private final List<GlobalSecondaryIndex> globalSecondaryIndexes; private final List<GlobalSecondaryIndex> globalSecondaryIndexes;
private final List<LocalSecondaryIndex> localSecondaryIndexes;
private final long readCapacityUnits; private final long readCapacityUnits;
private final long writeCapacityUnits; private final long writeCapacityUnits;
@ -53,12 +55,16 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
private DynamoDbAsyncClient dynamoAsyncDB2; private DynamoDbAsyncClient dynamoAsyncDB2;
private AmazonDynamoDB legacyDynamoClient; private AmazonDynamoDB legacyDynamoClient;
private DynamoDbExtension(String tableName, String hashKey, String rangeKey, List<AttributeDefinition> attributeDefinitions, List<GlobalSecondaryIndex> globalSecondaryIndexes, long readCapacityUnits, private DynamoDbExtension(String tableName, String hashKey, String rangeKey,
List<AttributeDefinition> attributeDefinitions, List<GlobalSecondaryIndex> globalSecondaryIndexes,
final List<LocalSecondaryIndex> localSecondaryIndexes,
long readCapacityUnits,
long writeCapacityUnits) { long writeCapacityUnits) {
this.tableName = tableName; this.tableName = tableName;
this.hashKeyName = hashKey; this.hashKeyName = hashKey;
this.rangeKeyName = rangeKey; this.rangeKeyName = rangeKey;
this.localSecondaryIndexes = localSecondaryIndexes;
this.readCapacityUnits = readCapacityUnits; this.readCapacityUnits = readCapacityUnits;
this.writeCapacityUnits = writeCapacityUnits; this.writeCapacityUnits = writeCapacityUnits;
@ -108,6 +114,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
.keySchema(keySchemaElements) .keySchema(keySchemaElements)
.attributeDefinitions(attributeDefinitions.isEmpty() ? null : attributeDefinitions) .attributeDefinitions(attributeDefinitions.isEmpty() ? null : attributeDefinitions)
.globalSecondaryIndexes(globalSecondaryIndexes.isEmpty() ? null : globalSecondaryIndexes) .globalSecondaryIndexes(globalSecondaryIndexes.isEmpty() ? null : globalSecondaryIndexes)
.localSecondaryIndexes(localSecondaryIndexes.isEmpty() ? null : localSecondaryIndexes)
.provisionedThroughput(ProvisionedThroughput.builder() .provisionedThroughput(ProvisionedThroughput.builder()
.readCapacityUnits(readCapacityUnits) .readCapacityUnits(readCapacityUnits)
.writeCapacityUnits(writeCapacityUnits) .writeCapacityUnits(writeCapacityUnits)
@ -150,7 +157,8 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
.build(); .build();
} }
static class DynamoDbExtensionBuilder { public static class DynamoDbExtensionBuilder {
private String tableName = DEFAULT_TABLE_NAME; private String tableName = DEFAULT_TABLE_NAME;
private String hashKey; private String hashKey;
@ -158,6 +166,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
private List<AttributeDefinition> attributeDefinitions = new ArrayList<>(); private List<AttributeDefinition> attributeDefinitions = new ArrayList<>();
private List<GlobalSecondaryIndex> globalSecondaryIndexes = new ArrayList<>(); private List<GlobalSecondaryIndex> globalSecondaryIndexes = new ArrayList<>();
private List<LocalSecondaryIndex> localSecondaryIndexes = new ArrayList<>();
private long readCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.readCapacityUnits(); private long readCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.readCapacityUnits();
private long writeCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.writeCapacityUnits(); private long writeCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.writeCapacityUnits();
@ -166,22 +175,22 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
} }
DynamoDbExtensionBuilder tableName(String databaseName) { public DynamoDbExtensionBuilder tableName(String databaseName) {
this.tableName = databaseName; this.tableName = databaseName;
return this; return this;
} }
DynamoDbExtensionBuilder hashKey(String hashKey) { public DynamoDbExtensionBuilder hashKey(String hashKey) {
this.hashKey = hashKey; this.hashKey = hashKey;
return this; return this;
} }
DynamoDbExtensionBuilder rangeKey(String rangeKey) { public DynamoDbExtensionBuilder rangeKey(String rangeKey) {
this.rangeKey = rangeKey; this.rangeKey = rangeKey;
return this; return this;
} }
DynamoDbExtensionBuilder attributeDefinition(AttributeDefinition attributeDefinition) { public DynamoDbExtensionBuilder attributeDefinition(AttributeDefinition attributeDefinition) {
attributeDefinitions.add(attributeDefinition); attributeDefinitions.add(attributeDefinition);
return this; return this;
} }
@ -191,9 +200,14 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
return this; return this;
} }
DynamoDbExtension build() { public DynamoDbExtensionBuilder localSecondaryIndex(LocalSecondaryIndex index) {
localSecondaryIndexes.add(index);
return this;
}
public DynamoDbExtension build() {
return new DynamoDbExtension(tableName, hashKey, rangeKey, return new DynamoDbExtension(tableName, hashKey, rangeKey,
attributeDefinitions, globalSecondaryIndexes, readCapacityUnits, writeCapacityUnits); attributeDefinitions, globalSecondaryIndexes, localSecondaryIndexes, readCapacityUnits, writeCapacityUnits);
} }
} }

View File

@ -5,7 +5,8 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -24,158 +25,159 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.After; import org.junit.jupiter.api.AfterEach;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule; import org.junit.jupiter.api.Test;
import org.junit.Test; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
public class MessagePersisterIntegrationTest extends AbstractRedisClusterTest { class MessagePersisterIntegrationTest {
@Rule @RegisterExtension
public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule(); static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build();
private ExecutorService notificationExecutorService; @RegisterExtension
private MessagesCache messagesCache; static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private MessagesManager messagesManager;
private MessagePersister messagePersister;
private Account account;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(10); private ExecutorService notificationExecutorService;
private MessagesCache messagesCache;
private MessagesManager messagesManager;
private MessagePersister messagePersister;
private Account account;
@Before private static final Duration PERSIST_DELAY = Duration.ofMinutes(10);
@Override
public void setUp() throws Exception {
super.setUp();
getRedisCluster().useCluster(connection -> { @BeforeEach
connection.sync().flushall(); void setUp() throws Exception {
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
}); connection.sync().flushall();
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7)); final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
final AccountsManager accountsManager = mock(AccountsManager.class); MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14));
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); final AccountsManager accountsManager = mock(AccountsManager.class);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
notificationExecutorService = Executors.newSingleThreadExecutor(); notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), mock(ReportMessageManager.class)); REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class),
mock(ReportMessageManager.class));
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);
account = mock(Account.class); account = mock(Account.class);
final UUID accountUuid = UUID.randomUUID(); final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account)); when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account));
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
messagesCache.start(); messagesCache.start();
} }
@After @AfterEach
@Override void tearDown() throws Exception {
public void tearDown() throws Exception { notificationExecutorService.shutdown();
super.tearDown(); notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS);
}
notificationExecutorService.shutdown(); @Test
notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS); void testScheduledPersistMessages() {
}
@Test(timeout = 15_000) final int messageCount = 377;
public void testScheduledPersistMessages() throws Exception { final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(messageCount);
final int messageCount = 377; final Instant now = Instant.now();
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(messageCount);
final Instant now = Instant.now();
for (int i = 0; i < messageCount; i++) { assertTimeout(Duration.ofSeconds(15), () -> {
final UUID messageGuid = UUID.randomUUID();
final long timestamp = now.minus(PERSIST_DELAY.multipliedBy(2)).toEpochMilli() + i;
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final long timestamp = now.minus(PERSIST_DELAY.multipliedBy(2)).toEpochMilli() + i;
messagesCache.insert(messageGuid, account.getUuid(), 1, message); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
expectedMessages.add(message);
messagesCache.insert(messageGuid, account.getUuid(), 1, message);
expectedMessages.add(message);
}
REDIS_CLUSTER_EXTENSION.getRedisCluster()
.useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY,
String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1)));
final AtomicBoolean messagesPersisted = new AtomicBoolean(false);
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
@Override
public void handleNewMessagesAvailable() {
} }
getRedisCluster().useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1))); @Override
public void handleNewEphemeralMessageAvailable() {
final AtomicBoolean messagesPersisted = new AtomicBoolean(false);
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
@Override
public void handleNewMessagesAvailable() {
}
@Override
public void handleNewEphemeralMessageAvailable() {
}
@Override
public void handleMessagesPersisted() {
synchronized (messagesPersisted) {
messagesPersisted.set(true);
messagesPersisted.notifyAll();
}
}
});
messagePersister.start();
synchronized (messagesPersisted) {
while (!messagesPersisted.get()) {
messagesPersisted.wait();
}
} }
messagePersister.stop(); @Override
public void handleMessagesPersisted() {
synchronized (messagesPersisted) {
messagesPersisted.set(true);
messagesPersisted.notifyAll();
}
}
});
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(messageCount); messagePersister.start();
DynamoDbClient dynamoDB = messagesDynamoDbRule.getDynamoDbClient(); synchronized (messagesPersisted) {
while (!messagesPersisted.get()) {
messagesPersisted.wait();
}
}
messagePersister.stop();
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(messageCount);
DynamoDbClient dynamoDB = dynamoDbExtension.getDynamoDbClient();
for (Map<String, AttributeValue> item : dynamoDB for (Map<String, AttributeValue> item : dynamoDB
.scan(ScanRequest.builder().tableName(MessagesDynamoDbRule.TABLE_NAME).build()).items()) { .scan(ScanRequest.builder().tableName(MessagesDynamoDbExtension.TABLE_NAME).build()).items()) {
persistedMessages.add(MessageProtos.Envelope.newBuilder() persistedMessages.add(MessageProtos.Envelope.newBuilder()
.setServerGuid(AttributeValues.getUUID(item, "U", null).toString()) .setServerGuid(AttributeValues.getUUID(item, "U", null).toString())
.setType(MessageProtos.Envelope.Type.valueOf(AttributeValues.getInt(item, "T", -1))) .setType(Type.forNumber(AttributeValues.getInt(item, "T", -1)))
.setTimestamp(AttributeValues.getLong(item, "TS", -1)) .setTimestamp(AttributeValues.getLong(item, "TS", -1))
.setServerTimestamp(extractServerTimestamp(AttributeValues.getByteArray(item, "S", null))) .setServerTimestamp(extractServerTimestamp(AttributeValues.getByteArray(item, "S", null)))
.setContent(ByteString.copyFrom(AttributeValues.getByteArray(item, "C", null))) .setContent(ByteString.copyFrom(AttributeValues.getByteArray(item, "C", null)))
.build()); .build());
} }
assertEquals(expectedMessages, persistedMessages); assertEquals(expectedMessages, persistedMessages);
} });
}
private static UUID convertBinaryToUuid(byte[] bytes) { private static long extractServerTimestamp(byte[] bytes) {
ByteBuffer bb = ByteBuffer.wrap(bytes); ByteBuffer bb = ByteBuffer.wrap(bytes);
long msb = bb.getLong(); bb.getLong();
long lsb = bb.getLong(); return bb.getLong();
return new UUID(msb, lsb); }
}
private static long extractServerTimestamp(byte[] bytes) { private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long timestamp) {
ByteBuffer bb = ByteBuffer.wrap(bytes); return MessageProtos.Envelope.newBuilder()
bb.getLong(); .setTimestamp(timestamp)
return bb.getLong(); .setServerTimestamp(timestamp)
} .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long timestamp) { .setServerGuid(messageGuid.toString())
return MessageProtos.Envelope.newBuilder() .build();
.setTimestamp(timestamp) }
.setServerTimestamp(timestamp)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
}
} }

View File

@ -13,15 +13,18 @@ import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.ClassRule; import org.junit.jupiter.api.Test;
import org.junit.Test; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
class MessagesDynamoDbTest {
public class MessagesDynamoDbTest {
private static final Random random = new Random(); private static final Random random = new Random();
private static final MessageProtos.Envelope MESSAGE1; private static final MessageProtos.Envelope MESSAGE1;
private static final MessageProtos.Envelope MESSAGE2; private static final MessageProtos.Envelope MESSAGE2;
@ -61,27 +64,31 @@ public class MessagesDynamoDbTest {
private MessagesDynamoDb messagesDynamoDb; private MessagesDynamoDb messagesDynamoDb;
@ClassRule
public static MessagesDynamoDbRule dynamoDbRule = new MessagesDynamoDbRule();
@Before @RegisterExtension
public void setup() { static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build();
messagesDynamoDb = new MessagesDynamoDb(dynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7));
@BeforeEach
void setup() {
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
Duration.ofDays(14));
} }
@Test @Test
public void testServerStart() { void testServerStart() {
} }
@Test @Test
public void testSimpleFetchAfterInsert() { void testSimpleFetchAfterInsert() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1; final int destinationDeviceId = random.nextInt(255) + 1;
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<OutgoingMessageEntity> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); final List<OutgoingMessageEntity> messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
MessagesDynamoDb.RESULT_SET_CHUNK_SIZE);
assertThat(messagesStored).isNotNull().hasSize(3); assertThat(messagesStored).isNotNull().hasSize(3);
final MessageProtos.Envelope firstMessage = MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3; final MessageProtos.Envelope firstMessage =
MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3;
final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1; final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1;
assertThat(messagesStored).element(0).satisfies(verify(firstMessage)); assertThat(messagesStored).element(0).satisfies(verify(firstMessage));
assertThat(messagesStored).element(1).satisfies(verify(secondMessage)); assertThat(messagesStored).element(1).satisfies(verify(secondMessage));
@ -89,61 +96,76 @@ public class MessagesDynamoDbTest {
} }
@Test @Test
public void testDeleteForDestination() { void testDeleteForDestination() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); .element(0).satisfies(verify(MESSAGE1));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
} }
@Test @Test
public void testDeleteForDestinationDevice() { void testDeleteForDestinationDevice() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); .element(0).satisfies(verify(MESSAGE1));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE1));
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
} }
@Test @Test
public void testDeleteMessageByDestinationAndGuid() { void testDeleteMessageByDestinationAndGuid() {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); .element(0).satisfies(verify(MESSAGE1));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).satisfies(verify(MESSAGE2));
messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid, messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid,
UUID.fromString(MESSAGE2.getServerGuid())); UUID.fromString(MESSAGE2.getServerGuid()));
assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); .element(0).satisfies(verify(MESSAGE1));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).satisfies(verify(MESSAGE3));
assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
} }
private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) { private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) {
@ -164,6 +186,7 @@ public class MessagesDynamoDbTest {
} }
private static final class VerifyMessage implements Consumer<OutgoingMessageEntity> { private static final class VerifyMessage implements Consumer<OutgoingMessageEntity> {
private final MessageProtos.Envelope expected; private final MessageProtos.Envelope expected;
public VerifyMessage(MessageProtos.Envelope expected) { public VerifyMessage(MessageProtos.Envelope expected) {

View File

@ -5,42 +5,36 @@
package org.whispersystems.textsecuregcm.tests.util; package org.whispersystems.textsecuregcm.tests.util;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType; import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.Projection; import software.amazon.awssdk.services.dynamodb.model.Projection;
import software.amazon.awssdk.services.dynamodb.model.ProjectionType; import software.amazon.awssdk.services.dynamodb.model.ProjectionType;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
public class MessagesDynamoDbRule extends LocalDynamoDbRule { public class MessagesDynamoDbExtension {
public static final String TABLE_NAME = "Signal_Messages_UnitTest"; public static final String TABLE_NAME = "Signal_Messages_UnitTest";
@Override public static DynamoDbExtension build() {
protected void before() throws Throwable { return DynamoDbExtension.builder()
super.before();
getDynamoDbClient().createTable(CreateTableRequest.builder()
.tableName(TABLE_NAME) .tableName(TABLE_NAME)
.keySchema(KeySchemaElement.builder().attributeName("H").keyType(KeyType.HASH).build(), .hashKey("H")
KeySchemaElement.builder().attributeName("S").keyType(KeyType.RANGE).build()) .rangeKey("S")
.attributeDefinitions( .attributeDefinition(
AttributeDefinition.builder().attributeName("H").attributeType(ScalarAttributeType.B).build(), AttributeDefinition.builder().attributeName("H").attributeType(ScalarAttributeType.B).build())
AttributeDefinition.builder().attributeName("S").attributeType(ScalarAttributeType.B).build(), .attributeDefinition(
AttributeDefinition.builder().attributeName("S").attributeType(ScalarAttributeType.B).build())
.attributeDefinition(
AttributeDefinition.builder().attributeName("U").attributeType(ScalarAttributeType.B).build()) AttributeDefinition.builder().attributeName("U").attributeType(ScalarAttributeType.B).build())
.provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(20L).writeCapacityUnits(20L).build()) .localSecondaryIndex(LocalSecondaryIndex.builder().indexName("Message_UUID_Index")
.localSecondaryIndexes(LocalSecondaryIndex.builder().indexName("Message_UUID_Index")
.keySchema(KeySchemaElement.builder().attributeName("H").keyType(KeyType.HASH).build(), .keySchema(KeySchemaElement.builder().attributeName("H").keyType(KeyType.HASH).build(),
KeySchemaElement.builder().attributeName("U").keyType(KeyType.RANGE).build()) KeySchemaElement.builder().attributeName("U").keyType(KeyType.RANGE).build())
.projection(Projection.builder().projectionType(ProjectionType.KEYS_ONLY).build()) .projection(Projection.builder().projectionType(ProjectionType.KEYS_ONLY).build())
.build()) .build())
.build()); .build();
} }
@Override
protected void after() {
super.after();
}
} }

View File

@ -5,9 +5,10 @@
package org.whispersystems.textsecuregcm.websocket; package org.whispersystems.textsecuregcm.websocket;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -34,10 +35,10 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.After; import org.junit.jupiter.api.AfterEach;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Rule; import org.junit.jupiter.api.Test;
import org.junit.Test; import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -45,195 +46,208 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest { class WebSocketConnectionIntegrationTest {
@Rule @RegisterExtension
public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule(); static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build();
private ExecutorService executorService; @RegisterExtension
private MessagesDynamoDb messagesDynamoDb; static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
private long serialTimestamp = System.currentTimeMillis(); private ExecutorService executorService;
private MessagesDynamoDb messagesDynamoDb;
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
@Before private long serialTimestamp = System.currentTimeMillis();
@Override
public void setUp() throws Exception {
super.setUp();
executorService = Executors.newSingleThreadExecutor(); @BeforeEach
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), executorService); void setUp() throws Exception {
messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7));
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
webSocketClient = mock(WebSocketClient.class);
retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor();
when(account.getNumber()).thenReturn("+18005551234"); executorService = Executors.newSingleThreadExecutor();
when(account.getUuid()).thenReturn(UUID.randomUUID()); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
when(device.getId()).thenReturn(1L); REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService);
messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
Duration.ofDays(7));
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
webSocketClient = mock(WebSocketClient.class);
retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor();
webSocketConnection = new WebSocketConnection( when(account.getNumber()).thenReturn("+18005551234");
mock(ReceiptSender.class), when(account.getUuid()).thenReturn(UUID.randomUUID());
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), when(device.getId()).thenReturn(1L);
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor);
}
@After webSocketConnection = new WebSocketConnection(
@Override mock(ReceiptSender.class),
public void tearDown() throws Exception { new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
executorService.shutdown(); new AuthenticatedAccount(() -> new Pair<>(account, device)),
executorService.awaitTermination(2, TimeUnit.SECONDS); device,
webSocketClient,
retrySchedulingExecutor);
}
retrySchedulingExecutor.shutdown(); @AfterEach
retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS); void tearDown() throws Exception {
executorService.shutdown();
executorService.awaitTermination(2, TimeUnit.SECONDS);
super.tearDown(); retrySchedulingExecutor.shutdown();
} retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS);
}
@Test(timeout = 15_000) @Test
public void testProcessStoredMessages() throws InterruptedException { void testProcessStoredMessages() {
final int persistedMessageCount = 207; final int persistedMessageCount = 207;
final int cachedMessageCount = 173; final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
{ assertTimeout(Duration.ofSeconds(15), () -> {
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) { {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
persistedMessages.add(envelope); for (int i = 0; i < persistedMessageCount; i++) {
expectedMessages.add(envelope); final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); persistedMessages.add(envelope);
expectedMessages.add(envelope);
} }
for (int i = 0; i < cachedMessageCount; i++) { messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
final UUID messageGuid = UUID.randomUUID(); }
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); for (int i = 0; i < cachedMessageCount; i++) {
expectedMessages.add(envelope); final UUID messageGuid = UUID.randomUUID();
} final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
final AtomicBoolean queueCleared = new AtomicBoolean(false); expectedMessages.add(envelope);
}
when(successResponse.getStatus()).thenReturn(200); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); final AtomicBoolean queueCleared = new AtomicBoolean(false);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer((Answer<CompletableFuture<WebSocketResponseMessage>>)invocation -> { when(successResponse.getStatus()).thenReturn(200);
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer(
(Answer<CompletableFuture<WebSocketResponseMessage>>) invocation -> {
synchronized (queueCleared) { synchronized (queueCleared) {
queueCleared.set(true); queueCleared.set(true);
queueCleared.notifyAll(); queueCleared.notifyAll();
} }
return CompletableFuture.completedFuture(successResponse); return CompletableFuture.completedFuture(successResponse);
});
webSocketConnection.processStoredMessages();
synchronized (queueCleared) {
while (!queueCleared.get()) {
queueCleared.wait();
}
}
@SuppressWarnings("unchecked") final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(
Optional.class);
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>();
for (final Optional<byte[]> maybeMessageBody : messageBodyCaptor.getAllValues()) {
maybeMessageBody.ifPresent(messageBytes -> {
try {
sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes));
} catch (final InvalidProtocolBufferException e) {
fail("Could not parse sent message");
}
}); });
}
webSocketConnection.processStoredMessages(); assertEquals(expectedMessages, sentMessages);
});
}
synchronized (queueCleared) { @Test
while (!queueCleared.get()) { void testProcessStoredMessagesClientClosed() {
queueCleared.wait(); final int persistedMessageCount = 207;
} final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
assertTimeout(Duration.ofSeconds(15), () -> {
{
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount);
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
} }
@SuppressWarnings("unchecked") messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
final ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); }
verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); for (int i = 0; i < cachedMessageCount; i++) {
verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
final List<MessageProtos.Envelope> sentMessages = new ArrayList<>(); expectedMessages.add(envelope);
}
for (final Optional<byte[]> maybeMessageBody : messageBodyCaptor.getAllValues()) { when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
maybeMessageBody.ifPresent(messageBytes -> { CompletableFuture.failedFuture(new IOException("Connection closed")));
try {
sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes));
} catch (final InvalidProtocolBufferException e) {
fail("Could not parse sent message");
}
});
}
assertEquals(expectedMessages, sentMessages); webSocketConnection.processStoredMessages();
}
@Test(timeout = 15_000) //noinspection unchecked
public void testProcessStoredMessagesClientClosed() { ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"),
eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
{ verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(),
final List<MessageProtos.Envelope> persistedMessages = new ArrayList<>(persistedMessageCount); eq(Optional.empty()));
for (int i = 0; i < persistedMessageCount; i++) {
final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID());
persistedMessages.add(envelope);
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId());
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
expectedMessages.add(envelope);
}
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed")));
webSocketConnection.processStoredMessages();
//noinspection unchecked
ArgumentCaptor<Optional<byte[]>> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class);
verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture());
verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty()));
final List<MessageProtos.Envelope> sentMessages = messageBodyCaptor.getAllValues().stream() final List<MessageProtos.Envelope> sentMessages = messageBodyCaptor.getAllValues().stream()
.map(Optional::get) .map(Optional::get)
.map(messageBytes -> { .map(messageBytes -> {
try { try {
return Envelope.parseFrom(messageBytes); return Envelope.parseFrom(messageBytes);
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
}) })
.collect(Collectors.toList()); .collect(Collectors.toList());
assertTrue(expectedMessages.containsAll(sentMessages)); assertTrue(expectedMessages.containsAll(sentMessages));
});
} }
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) {

View File

@ -65,18 +65,19 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
} }
} }
return new WebSocketResourceProvider<T>(getRemoteAddress(request), return new WebSocketResourceProvider<>(getRemoteAddress(request),
this.jerseyApplicationHandler, this.jerseyApplicationHandler,
this.environment.getRequestLog(), this.environment.getRequestLog(),
authenticated, authenticated,
this.environment.getMessageFactory(), this.environment.getMessageFactory(),
ofNullable(this.environment.getConnectListener()), ofNullable(this.environment.getConnectListener()),
this.environment.getIdleTimeoutMillis()); this.environment.getIdleTimeoutMillis());
} catch (AuthenticationException | IOException e) { } catch (AuthenticationException | IOException e) {
logger.warn("Authentication failure", e); logger.warn("Authentication failure", e);
try { try {
response.sendError(500, "Failure"); response.sendError(500, "Failure");
} catch (IOException ex) {} } catch (IOException ignored) {
}
return null; return null;
} }
} }