diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java index 75b4f86e7..38cad7696 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtension.java @@ -25,6 +25,7 @@ import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback { @@ -45,6 +46,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback private final List attributeDefinitions; private final List globalSecondaryIndexes; + private final List localSecondaryIndexes; private final long readCapacityUnits; private final long writeCapacityUnits; @@ -53,12 +55,16 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback private DynamoDbAsyncClient dynamoAsyncDB2; private AmazonDynamoDB legacyDynamoClient; - private DynamoDbExtension(String tableName, String hashKey, String rangeKey, List attributeDefinitions, List globalSecondaryIndexes, long readCapacityUnits, + private DynamoDbExtension(String tableName, String hashKey, String rangeKey, + List attributeDefinitions, List globalSecondaryIndexes, + final List localSecondaryIndexes, + long readCapacityUnits, long writeCapacityUnits) { this.tableName = tableName; this.hashKeyName = hashKey; this.rangeKeyName = rangeKey; + this.localSecondaryIndexes = localSecondaryIndexes; this.readCapacityUnits = readCapacityUnits; this.writeCapacityUnits = writeCapacityUnits; @@ -108,6 +114,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback .keySchema(keySchemaElements) .attributeDefinitions(attributeDefinitions.isEmpty() ? null : attributeDefinitions) .globalSecondaryIndexes(globalSecondaryIndexes.isEmpty() ? null : globalSecondaryIndexes) + .localSecondaryIndexes(localSecondaryIndexes.isEmpty() ? null : localSecondaryIndexes) .provisionedThroughput(ProvisionedThroughput.builder() .readCapacityUnits(readCapacityUnits) .writeCapacityUnits(writeCapacityUnits) @@ -150,7 +157,8 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback .build(); } - static class DynamoDbExtensionBuilder { + public static class DynamoDbExtensionBuilder { + private String tableName = DEFAULT_TABLE_NAME; private String hashKey; @@ -158,6 +166,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback private List attributeDefinitions = new ArrayList<>(); private List globalSecondaryIndexes = new ArrayList<>(); + private List localSecondaryIndexes = new ArrayList<>(); private long readCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.readCapacityUnits(); private long writeCapacityUnits = DEFAULT_PROVISIONED_THROUGHPUT.writeCapacityUnits(); @@ -166,22 +175,22 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback } - DynamoDbExtensionBuilder tableName(String databaseName) { + public DynamoDbExtensionBuilder tableName(String databaseName) { this.tableName = databaseName; return this; } - DynamoDbExtensionBuilder hashKey(String hashKey) { + public DynamoDbExtensionBuilder hashKey(String hashKey) { this.hashKey = hashKey; return this; } - DynamoDbExtensionBuilder rangeKey(String rangeKey) { + public DynamoDbExtensionBuilder rangeKey(String rangeKey) { this.rangeKey = rangeKey; return this; } - DynamoDbExtensionBuilder attributeDefinition(AttributeDefinition attributeDefinition) { + public DynamoDbExtensionBuilder attributeDefinition(AttributeDefinition attributeDefinition) { attributeDefinitions.add(attributeDefinition); return this; } @@ -191,9 +200,14 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback return this; } - DynamoDbExtension build() { + public DynamoDbExtensionBuilder localSecondaryIndex(LocalSecondaryIndex index) { + localSecondaryIndexes.add(index); + return this; + } + + public DynamoDbExtension build() { return new DynamoDbExtension(tableName, hashKey, rangeKey, - attributeDefinitions, globalSecondaryIndexes, readCapacityUnits, writeCapacityUnits); + attributeDefinitions, globalSecondaryIndexes, localSecondaryIndexes, readCapacityUnits, writeCapacityUnits); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 9a20dae0e..0f87a55d8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -5,7 +5,8 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTimeout; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -24,158 +25,159 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; -import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; import org.whispersystems.textsecuregcm.util.AttributeValues; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; -public class MessagePersisterIntegrationTest extends AbstractRedisClusterTest { +class MessagePersisterIntegrationTest { - @Rule - public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule(); + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build(); - private ExecutorService notificationExecutorService; - private MessagesCache messagesCache; - private MessagesManager messagesManager; - private MessagePersister messagePersister; - private Account account; + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - private static final Duration PERSIST_DELAY = Duration.ofMinutes(10); + private ExecutorService notificationExecutorService; + private MessagesCache messagesCache; + private MessagesManager messagesManager; + private MessagePersister messagePersister; + private Account account; - @Before - @Override - public void setUp() throws Exception { - super.setUp(); + private static final Duration PERSIST_DELAY = Duration.ofMinutes(10); - getRedisCluster().useCluster(connection -> { - connection.sync().flushall(); - connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); - }); + @BeforeEach + void setUp() throws Exception { + REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> { + connection.sync().flushall(); + connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); + }); - final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7)); - final AccountsManager accountsManager = mock(AccountsManager.class); - final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), + MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14)); + final AccountsManager accountsManager = mock(AccountsManager.class); + final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); - notificationExecutorService = Executors.newSingleThreadExecutor(); - messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); - messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), mock(ReportMessageManager.class)); - messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); + notificationExecutorService = Executors.newSingleThreadExecutor(); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); + messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), + mock(ReportMessageManager.class)); + messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, + dynamicConfigurationManager, PERSIST_DELAY); - account = mock(Account.class); + account = mock(Account.class); - final UUID accountUuid = UUID.randomUUID(); + final UUID accountUuid = UUID.randomUUID(); - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account)); - when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(accountsManager.get(accountUuid)).thenReturn(Optional.of(account)); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); - messagesCache.start(); - } + messagesCache.start(); + } - @After - @Override - public void tearDown() throws Exception { - super.tearDown(); + @AfterEach + void tearDown() throws Exception { + notificationExecutorService.shutdown(); + notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS); + } - notificationExecutorService.shutdown(); - notificationExecutorService.awaitTermination(15, TimeUnit.SECONDS); - } + @Test + void testScheduledPersistMessages() { - @Test(timeout = 15_000) - public void testScheduledPersistMessages() throws Exception { - final int messageCount = 377; - final List expectedMessages = new ArrayList<>(messageCount); - final Instant now = Instant.now(); + final int messageCount = 377; + final List expectedMessages = new ArrayList<>(messageCount); + final Instant now = Instant.now(); - for (int i = 0; i < messageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - final long timestamp = now.minus(PERSIST_DELAY.multipliedBy(2)).toEpochMilli() + i; + assertTimeout(Duration.ofSeconds(15), () -> { - final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final long timestamp = now.minus(PERSIST_DELAY.multipliedBy(2)).toEpochMilli() + i; - messagesCache.insert(messageGuid, account.getUuid(), 1, message); - expectedMessages.add(message); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); + + messagesCache.insert(messageGuid, account.getUuid(), 1, message); + expectedMessages.add(message); + } + + REDIS_CLUSTER_EXTENSION.getRedisCluster() + .useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, + String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1))); + + final AtomicBoolean messagesPersisted = new AtomicBoolean(false); + + messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() { + @Override + public void handleNewMessagesAvailable() { } - getRedisCluster().useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1))); - - final AtomicBoolean messagesPersisted = new AtomicBoolean(false); - - messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() { - @Override - public void handleNewMessagesAvailable() { - } - - @Override - public void handleNewEphemeralMessageAvailable() { - } - - @Override - public void handleMessagesPersisted() { - synchronized (messagesPersisted) { - messagesPersisted.set(true); - messagesPersisted.notifyAll(); - } - } - }); - - messagePersister.start(); - - synchronized (messagesPersisted) { - while (!messagesPersisted.get()) { - messagesPersisted.wait(); - } + @Override + public void handleNewEphemeralMessageAvailable() { } - messagePersister.stop(); + @Override + public void handleMessagesPersisted() { + synchronized (messagesPersisted) { + messagesPersisted.set(true); + messagesPersisted.notifyAll(); + } + } + }); - final List persistedMessages = new ArrayList<>(messageCount); + messagePersister.start(); - DynamoDbClient dynamoDB = messagesDynamoDbRule.getDynamoDbClient(); + synchronized (messagesPersisted) { + while (!messagesPersisted.get()) { + messagesPersisted.wait(); + } + } + + messagePersister.stop(); + + final List persistedMessages = new ArrayList<>(messageCount); + + DynamoDbClient dynamoDB = dynamoDbExtension.getDynamoDbClient(); for (Map item : dynamoDB - .scan(ScanRequest.builder().tableName(MessagesDynamoDbRule.TABLE_NAME).build()).items()) { + .scan(ScanRequest.builder().tableName(MessagesDynamoDbExtension.TABLE_NAME).build()).items()) { persistedMessages.add(MessageProtos.Envelope.newBuilder() .setServerGuid(AttributeValues.getUUID(item, "U", null).toString()) - .setType(MessageProtos.Envelope.Type.valueOf(AttributeValues.getInt(item, "T", -1))) + .setType(Type.forNumber(AttributeValues.getInt(item, "T", -1))) .setTimestamp(AttributeValues.getLong(item, "TS", -1)) .setServerTimestamp(extractServerTimestamp(AttributeValues.getByteArray(item, "S", null))) .setContent(ByteString.copyFrom(AttributeValues.getByteArray(item, "C", null))) .build()); - } + } - assertEquals(expectedMessages, persistedMessages); - } + assertEquals(expectedMessages, persistedMessages); + }); + } - private static UUID convertBinaryToUuid(byte[] bytes) { - ByteBuffer bb = ByteBuffer.wrap(bytes); - long msb = bb.getLong(); - long lsb = bb.getLong(); - return new UUID(msb, lsb); - } + private static long extractServerTimestamp(byte[] bytes) { + ByteBuffer bb = ByteBuffer.wrap(bytes); + bb.getLong(); + return bb.getLong(); + } - private static long extractServerTimestamp(byte[] bytes) { - ByteBuffer bb = ByteBuffer.wrap(bytes); - bb.getLong(); - return bb.getLong(); - } - - private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long timestamp) { - return MessageProtos.Envelope.newBuilder() - .setTimestamp(timestamp) - .setServerTimestamp(timestamp) - .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) - .setType(MessageProtos.Envelope.Type.CIPHERTEXT) - .setServerGuid(messageGuid.toString()) - .build(); - } + private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final long timestamp) { + return MessageProtos.Envelope.newBuilder() + .setTimestamp(timestamp) + .setServerTimestamp(timestamp) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(messageGuid.toString()) + .build(); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java index b197ab0e9..9c4158c7e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesDynamoDbTest.java @@ -13,15 +13,18 @@ import java.util.List; import java.util.Random; import java.util.UUID; import java.util.function.Consumer; -import org.junit.Before; -import org.junit.ClassRule; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; -import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; +import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; + +class MessagesDynamoDbTest { + -public class MessagesDynamoDbTest { private static final Random random = new Random(); private static final MessageProtos.Envelope MESSAGE1; private static final MessageProtos.Envelope MESSAGE2; @@ -61,27 +64,31 @@ public class MessagesDynamoDbTest { private MessagesDynamoDb messagesDynamoDb; - @ClassRule - public static MessagesDynamoDbRule dynamoDbRule = new MessagesDynamoDbRule(); - @Before - public void setup() { - messagesDynamoDb = new MessagesDynamoDb(dynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7)); + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build(); + + @BeforeEach + void setup() { + messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME, + Duration.ofDays(14)); } @Test - public void testServerStart() { + void testServerStart() { } @Test - public void testSimpleFetchAfterInsert() { + void testSimpleFetchAfterInsert() { final UUID destinationUuid = UUID.randomUUID(); final int destinationDeviceId = random.nextInt(255) + 1; messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); - final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); + final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId, + MessagesDynamoDb.RESULT_SET_CHUNK_SIZE); assertThat(messagesStored).isNotNull().hasSize(3); - final MessageProtos.Envelope firstMessage = MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3; + final MessageProtos.Envelope firstMessage = + MESSAGE1.getServerGuid().compareTo(MESSAGE3.getServerGuid()) < 0 ? MESSAGE1 : MESSAGE3; final MessageProtos.Envelope secondMessage = firstMessage == MESSAGE1 ? MESSAGE3 : MESSAGE1; assertThat(messagesStored).element(0).satisfies(verify(firstMessage)); assertThat(messagesStored).element(1).satisfies(verify(secondMessage)); @@ -89,61 +96,76 @@ public class MessagesDynamoDbTest { } @Test - public void testDeleteForDestination() { + void testDeleteForDestination() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); + assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE1)); + assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE3)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1).element(0).satisfies(verify(MESSAGE2)); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1).element(0).satisfies(verify(MESSAGE2)); } @Test - public void testDeleteForDestinationDevice() { + void testDeleteForDestinationDevice() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); + assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE1)); + assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE3)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1).element(0).satisfies(verify(MESSAGE2)); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); + assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE1)); assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1).element(0).satisfies(verify(MESSAGE2)); } @Test - public void testDeleteMessageByDestinationAndGuid() { + void testDeleteMessageByDestinationAndGuid() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE2)); + assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE1)); + assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE3)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1).element(0).satisfies(verify(MESSAGE2)); messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid, UUID.fromString(MESSAGE2.getServerGuid())); - assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE1)); - assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1).element(0).satisfies(verify(MESSAGE3)); - assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE1)); + assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + .element(0).satisfies(verify(MESSAGE3)); + assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .isEmpty(); } private static void verify(OutgoingMessageEntity retrieved, MessageProtos.Envelope inserted) { @@ -164,6 +186,7 @@ public class MessagesDynamoDbTest { } private static final class VerifyMessage implements Consumer { + private final MessageProtos.Envelope expected; public VerifyMessage(MessageProtos.Envelope expected) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbRule.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbExtension.java similarity index 57% rename from service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbRule.java rename to service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbExtension.java index 62c4ed2ce..1aeba95e1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbRule.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessagesDynamoDbExtension.java @@ -5,42 +5,36 @@ package org.whispersystems.textsecuregcm.tests.util; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; -import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; import software.amazon.awssdk.services.dynamodb.model.KeyType; import software.amazon.awssdk.services.dynamodb.model.LocalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.Projection; import software.amazon.awssdk.services.dynamodb.model.ProjectionType; -import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput; import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; -public class MessagesDynamoDbRule extends LocalDynamoDbRule { +public class MessagesDynamoDbExtension { public static final String TABLE_NAME = "Signal_Messages_UnitTest"; - @Override - protected void before() throws Throwable { - super.before(); - getDynamoDbClient().createTable(CreateTableRequest.builder() + public static DynamoDbExtension build() { + return DynamoDbExtension.builder() .tableName(TABLE_NAME) - .keySchema(KeySchemaElement.builder().attributeName("H").keyType(KeyType.HASH).build(), - KeySchemaElement.builder().attributeName("S").keyType(KeyType.RANGE).build()) - .attributeDefinitions( - AttributeDefinition.builder().attributeName("H").attributeType(ScalarAttributeType.B).build(), - AttributeDefinition.builder().attributeName("S").attributeType(ScalarAttributeType.B).build(), + .hashKey("H") + .rangeKey("S") + .attributeDefinition( + AttributeDefinition.builder().attributeName("H").attributeType(ScalarAttributeType.B).build()) + .attributeDefinition( + AttributeDefinition.builder().attributeName("S").attributeType(ScalarAttributeType.B).build()) + .attributeDefinition( AttributeDefinition.builder().attributeName("U").attributeType(ScalarAttributeType.B).build()) - .provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(20L).writeCapacityUnits(20L).build()) - .localSecondaryIndexes(LocalSecondaryIndex.builder().indexName("Message_UUID_Index") + .localSecondaryIndex(LocalSecondaryIndex.builder().indexName("Message_UUID_Index") .keySchema(KeySchemaElement.builder().attributeName("H").keyType(KeyType.HASH).build(), KeySchemaElement.builder().attributeName("U").keyType(KeyType.RANGE).build()) .projection(Projection.builder().projectionType(ProjectionType.KEYS_ONLY).build()) .build()) - .build()); + .build(); } - @Override - protected void after() { - super.after(); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 78de5abe1..671ba43c8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -5,9 +5,10 @@ package org.whispersystems.textsecuregcm.websocket; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTimeout; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.eq; @@ -34,10 +35,10 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -45,195 +46,208 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; -import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; +import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest { +class WebSocketConnectionIntegrationTest { - @Rule - public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule(); + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = MessagesDynamoDbExtension.build(); - private ExecutorService executorService; - private MessagesDynamoDb messagesDynamoDb; - private MessagesCache messagesCache; - private ReportMessageManager reportMessageManager; - private Account account; - private Device device; - private WebSocketClient webSocketClient; - private WebSocketConnection webSocketConnection; - private ScheduledExecutorService retrySchedulingExecutor; + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - private long serialTimestamp = System.currentTimeMillis(); + private ExecutorService executorService; + private MessagesDynamoDb messagesDynamoDb; + private MessagesCache messagesCache; + private ReportMessageManager reportMessageManager; + private Account account; + private Device device; + private WebSocketClient webSocketClient; + private WebSocketConnection webSocketConnection; + private ScheduledExecutorService retrySchedulingExecutor; - @Before - @Override - public void setUp() throws Exception { - super.setUp(); + private long serialTimestamp = System.currentTimeMillis(); - executorService = Executors.newSingleThreadExecutor(); - messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), executorService); - messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDbClient(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7)); - reportMessageManager = mock(ReportMessageManager.class); - account = mock(Account.class); - device = mock(Device.class); - webSocketClient = mock(WebSocketClient.class); - retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor(); + @BeforeEach + void setUp() throws Exception { - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(1L); + executorService = Executors.newSingleThreadExecutor(); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService); + messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME, + Duration.ofDays(7)); + reportMessageManager = mock(ReportMessageManager.class); + account = mock(Account.class); + device = mock(Device.class); + webSocketClient = mock(WebSocketClient.class); + retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor(); - webSocketConnection = new WebSocketConnection( - mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), - new AuthenticatedAccount(() -> new Pair<>(account, device)), - device, - webSocketClient, - retrySchedulingExecutor); - } + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); - @After - @Override - public void tearDown() throws Exception { - executorService.shutdown(); - executorService.awaitTermination(2, TimeUnit.SECONDS); + webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), + new AuthenticatedAccount(() -> new Pair<>(account, device)), + device, + webSocketClient, + retrySchedulingExecutor); + } - retrySchedulingExecutor.shutdown(); - retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS); + @AfterEach + void tearDown() throws Exception { + executorService.shutdown(); + executorService.awaitTermination(2, TimeUnit.SECONDS); - super.tearDown(); - } + retrySchedulingExecutor.shutdown(); + retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS); + } - @Test(timeout = 15_000) - public void testProcessStoredMessages() throws InterruptedException { - final int persistedMessageCount = 207; - final int cachedMessageCount = 173; + @Test + void testProcessStoredMessages() { + final int persistedMessageCount = 207; + final int cachedMessageCount = 173; - final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); + final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); - { - final List persistedMessages = new ArrayList<>(persistedMessageCount); + assertTimeout(Duration.ofSeconds(15), () -> { - for (int i = 0; i < persistedMessageCount; i++) { - final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); - persistedMessages.add(envelope); - expectedMessages.add(envelope); - } + for (int i = 0; i < persistedMessageCount; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); - messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + persistedMessages.add(envelope); + expectedMessages.add(envelope); } - for (int i = 0; i < cachedMessageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + } - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); - expectedMessages.add(envelope); - } + for (int i = 0; i < cachedMessageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - final AtomicBoolean queueCleared = new AtomicBoolean(false); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + expectedMessages.add(envelope); + } - when(successResponse.getStatus()).thenReturn(200); - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + final AtomicBoolean queueCleared = new AtomicBoolean(false); - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer((Answer>)invocation -> { + when(successResponse.getStatus()).thenReturn(200); + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( + CompletableFuture.completedFuture(successResponse)); + + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( + (Answer>) invocation -> { synchronized (queueCleared) { - queueCleared.set(true); - queueCleared.notifyAll(); + queueCleared.set(true); + queueCleared.notifyAll(); } return CompletableFuture.completedFuture(successResponse); + }); + + webSocketConnection.processStoredMessages(); + + synchronized (queueCleared) { + while (!queueCleared.get()) { + queueCleared.wait(); + } + } + + @SuppressWarnings("unchecked") final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass( + Optional.class); + + verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), + eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + + final List sentMessages = new ArrayList<>(); + + for (final Optional maybeMessageBody : messageBodyCaptor.getAllValues()) { + maybeMessageBody.ifPresent(messageBytes -> { + try { + sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes)); + } catch (final InvalidProtocolBufferException e) { + fail("Could not parse sent message"); + } }); + } - webSocketConnection.processStoredMessages(); + assertEquals(expectedMessages, sentMessages); + }); + } - synchronized (queueCleared) { - while (!queueCleared.get()) { - queueCleared.wait(); - } + @Test + void testProcessStoredMessagesClientClosed() { + final int persistedMessageCount = 207; + final int cachedMessageCount = 173; + + final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); + + assertTimeout(Duration.ofSeconds(15), () -> { + + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); + + for (int i = 0; i < persistedMessageCount; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + persistedMessages.add(envelope); + expectedMessages.add(envelope); } - @SuppressWarnings("unchecked") - final ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + } - verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); - verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + for (int i = 0; i < cachedMessageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); - final List sentMessages = new ArrayList<>(); + expectedMessages.add(envelope); + } - for (final Optional maybeMessageBody : messageBodyCaptor.getAllValues()) { - maybeMessageBody.ifPresent(messageBytes -> { - try { - sentMessages.add(MessageProtos.Envelope.parseFrom(messageBytes)); - } catch (final InvalidProtocolBufferException e) { - fail("Could not parse sent message"); - } - }); - } + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn( + CompletableFuture.failedFuture(new IOException("Connection closed"))); - assertEquals(expectedMessages, sentMessages); - } + webSocketConnection.processStoredMessages(); - @Test(timeout = 15_000) - public void testProcessStoredMessagesClientClosed() { - final int persistedMessageCount = 207; - final int cachedMessageCount = 173; + //noinspection unchecked + ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); - final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); - - { - final List persistedMessages = new ArrayList<>(persistedMessageCount); - - for (int i = 0; i < persistedMessageCount; i++) { - final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); - persistedMessages.add(envelope); - expectedMessages.add(envelope); - } - - messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); - } - - for (int i = 0; i < cachedMessageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); - - expectedMessages.add(envelope); - } - - when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(CompletableFuture.failedFuture(new IOException("Connection closed"))); - - webSocketConnection.processStoredMessages(); - - //noinspection unchecked - ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); - - verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); - verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), + eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + verify(webSocketClient, never()).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), + eq(Optional.empty())); final List sentMessages = messageBodyCaptor.getAllValues().stream() .map(Optional::get) .map(messageBytes -> { - try { - return Envelope.parseFrom(messageBytes); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } + try { + return Envelope.parseFrom(messageBytes); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } }) .collect(Collectors.toList()); assertTrue(expectedMessages.containsAll(sentMessages)); + }); } private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index ab0876243..8233e2346 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -65,18 +65,19 @@ public class WebSocketResourceProviderFactory extends WebSo } } - return new WebSocketResourceProvider(getRemoteAddress(request), - this.jerseyApplicationHandler, - this.environment.getRequestLog(), - authenticated, - this.environment.getMessageFactory(), - ofNullable(this.environment.getConnectListener()), - this.environment.getIdleTimeoutMillis()); + return new WebSocketResourceProvider<>(getRemoteAddress(request), + this.jerseyApplicationHandler, + this.environment.getRequestLog(), + authenticated, + this.environment.getMessageFactory(), + ofNullable(this.environment.getConnectListener()), + this.environment.getIdleTimeoutMillis()); } catch (AuthenticationException | IOException e) { logger.warn("Authentication failure", e); try { response.sendError(500, "Failure"); - } catch (IOException ex) {} + } catch (IOException ignored) { + } return null; } }