From 34dbff6786f8bfd57edf26b837b93c11e7e39dab Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 26 Jul 2021 11:40:53 -0400 Subject: [PATCH] Switch to an async SQS client. --- .../textsecuregcm/sqs/DirectoryQueue.java | 104 +++++++++--------- .../textsecuregcm/sqs/DirectoryQueueTest.java | 101 ++++++----------- 2 files changed, 87 insertions(+), 118 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java index e26265c9e..9b844e69a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java @@ -14,23 +14,19 @@ import com.google.common.annotations.VisibleForTesting; import java.util.List; import java.util.Map; import java.util.UUID; -import java.util.stream.Collectors; -import com.google.common.collect.Iterables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.SqsConfiguration; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Constants; -import org.whispersystems.textsecuregcm.util.Pair; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkServiceException; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; -import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; -import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; public class DirectoryQueue { @@ -41,22 +37,38 @@ public class DirectoryQueue { private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch")); - private final List queueUrls; - private final SqsClient sqs; + private final List queueUrls; + private final SqsAsyncClient sqs; + + private enum UpdateAction { + ADD("add"), + DELETE("delete"); + + private final String action; + + UpdateAction(final String action) { + this.action = action; + } + + public MessageAttributeValue toMessageAttributeValue() { + return MessageAttributeValue.builder().dataType("String").stringValue(action).build(); + } + } public DirectoryQueue(SqsConfiguration sqsConfig) { StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create( sqsConfig.getAccessKey(), sqsConfig.getAccessSecret())); this.queueUrls = sqsConfig.getQueueUrls(); - this.sqs = SqsClient.builder() + + this.sqs = SqsAsyncClient.builder() .region(Region.of(sqsConfig.getRegion())) .credentialsProvider(credentialsProvider) .build(); } @VisibleForTesting - DirectoryQueue(final List queueUrls, final SqsClient sqs) { + DirectoryQueue(final List queueUrls, final SqsAsyncClient sqs) { this.queueUrls = queueUrls; this.sqs = sqs; } @@ -66,58 +78,44 @@ public class DirectoryQueue { } public void refreshAccount(final Account account) { - refreshAccounts(List.of(account)); - } - - public void refreshAccounts(final List accounts) { - final List> accountsAndActions = accounts.stream() - .map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete")) - .collect(Collectors.toList()); - - sendUpdateMessages(accountsAndActions); + sendUpdateMessage(account, isDiscoverable(account) ? UpdateAction.ADD : UpdateAction.DELETE); } public void deleteAccount(final Account account) { - sendUpdateMessages(List.of(new Pair<>(account, "delete"))); + sendUpdateMessage(account, UpdateAction.DELETE); } - private void sendUpdateMessages(final List> accountsAndActions) { + private void sendUpdateMessage(final Account account, final UpdateAction action) { for (final String queueUrl : queueUrls) { - for (final List> partition : Iterables.partition(accountsAndActions, 10)) { - final List entries = partition.stream().map(pair -> { - final Account account = pair.first(); - final String action = pair.second(); + final Timer.Context timerContext = sendMessageBatchTimer.time(); - return SendMessageBatchRequestEntry.builder() - .messageBody("-") - .id(UUID.randomUUID().toString()) - .messageDeduplicationId(UUID.randomUUID().toString()) - .messageGroupId(account.getNumber()) - .messageAttributes(Map.of( - "id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(), - "uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(), - "action", MessageAttributeValue.builder().dataType("String").stringValue(action).build() - )) - .build(); - }).collect(Collectors.toList()); + final SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(queueUrl) + .messageBody("-") + .messageDeduplicationId(UUID.randomUUID().toString()) + .messageGroupId(account.getNumber()) + .messageAttributes(Map.of( + "id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(), + "uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(), + "action", action.toMessageAttributeValue() + )) + .build(); - final SendMessageBatchRequest sendMessageBatchRequest = SendMessageBatchRequest.builder() - .queueUrl(queueUrl) - .entries(entries) - .build(); - - try (final Timer.Context ignored = sendMessageBatchTimer.time()) { - sqs.sendMessageBatch(sendMessageBatchRequest); - } catch (SdkServiceException ex) { - serviceErrorMeter.mark(); - logger.warn("sqs service error: ", ex); - } catch (SdkClientException ex) { - clientErrorMeter.mark(); - logger.warn("sqs client error: ", ex); - } catch (Throwable t) { - logger.warn("sqs unexpected error: ", t); + sqs.sendMessage(request).whenComplete((response, cause) -> { + try { + if (cause instanceof SdkServiceException) { + serviceErrorMeter.mark(); + logger.warn("sqs service error", cause); + } else if (cause instanceof SdkClientException) { + clientErrorMeter.mark(); + logger.warn("sqs client error", cause); + } else if (cause != null) { + logger.warn("sqs unexpected error", cause); + } + } finally { + timerContext.close(); } - } + }); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java index 97889941b..8a371e7c0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java @@ -5,34 +5,45 @@ package org.whispersystems.textsecuregcm.sqs; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.storage.Account; -import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; -import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; - -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageResponse; public class DirectoryQueueTest { + private SqsAsyncClient sqsAsyncClient; + + @BeforeEach + void setUp() { + sqsAsyncClient = mock(SqsAsyncClient.class); + + when(sqsAsyncClient.sendMessage(any(SendMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(SendMessageResponse.builder().build())); + } + @ParameterizedTest @MethodSource("argumentsForTestRefreshRegisteredUser") void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) { - final SqsClient sqs = mock(SqsClient.class); - final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); + final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqsAsyncClient); final Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005556543"); @@ -42,13 +53,11 @@ public class DirectoryQueueTest { directoryQueue.refreshAccount(account); - final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); - verify(sqs).sendMessageBatch(requestCaptor.capture()); + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(sqsAsyncClient).sendMessage(requestCaptor.capture()); - assertEquals(1, requestCaptor.getValue().entries().size()); - - final Map messageAttributes = requestCaptor.getValue().entries().get(0).messageAttributes(); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(), messageAttributes.get("action")); + assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(), + requestCaptor.getValue().messageAttributes().get("action")); } @SuppressWarnings("unused") @@ -60,45 +69,9 @@ public class DirectoryQueueTest { Arguments.of(false, false, "delete")); } - @Test - void testRefreshBatch() { - final SqsClient sqs = mock(SqsClient.class); - final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); - - final Account discoverableAccount = mock(Account.class); - when(discoverableAccount.getNumber()).thenReturn("+18005556543"); - when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID()); - when(discoverableAccount.isEnabled()).thenReturn(true); - when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true); - - final Account undiscoverableAccount = mock(Account.class); - when(undiscoverableAccount.getNumber()).thenReturn("+18005550987"); - when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID()); - when(undiscoverableAccount.isEnabled()).thenReturn(true); - when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false); - - directoryQueue.refreshAccounts(List.of(discoverableAccount, undiscoverableAccount)); - - final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); - verify(sqs).sendMessageBatch(requestCaptor.capture()); - - assertEquals(2, requestCaptor.getValue().entries().size()); - - final Map discoverableAccountAttributes = requestCaptor.getValue().entries().get(0).messageAttributes(); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getNumber()).build(), discoverableAccountAttributes.get("id")); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getUuid().toString()).build(), discoverableAccountAttributes.get("uuid")); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), discoverableAccountAttributes.get("action")); - - final Map undiscoverableAccountAttributes = requestCaptor.getValue().entries().get(1).messageAttributes(); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getNumber()).build(), undiscoverableAccountAttributes.get("id")); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getUuid().toString()).build(), undiscoverableAccountAttributes.get("uuid")); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("delete").build(), undiscoverableAccountAttributes.get("action")); - } - @Test void testSendMessageMultipleQueues() { - final SqsClient sqs = mock(SqsClient.class); - final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs); + final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqsAsyncClient); final Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005556543"); @@ -108,14 +81,12 @@ public class DirectoryQueueTest { directoryQueue.refreshAccount(account); - final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); - verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture()); + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(sqsAsyncClient, times(2)).sendMessage(requestCaptor.capture()); - for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) { - assertEquals(1, requestCaptor.getValue().entries().size()); - - final Map messageAttributes = sendMessageBatchRequest.entries().get(0).messageAttributes(); - assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), messageAttributes.get("action")); + for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) { + assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), + sendMessageRequest.messageAttributes().get("action")); } } }