Switch SQS to Amazon SDKv2.

This commit is contained in:
Graeme Connell 2021-06-30 09:11:38 -06:00 committed by gram-signal
parent be6ef76486
commit 42ff3f8432
3 changed files with 60 additions and 59 deletions

View File

@ -259,6 +259,10 @@
<groupId>software.amazon.awssdk</groupId> <groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId> <artifactId>s3</artifactId>
</dependency> </dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sqs</artifactId>
</dependency>
<dependency> <dependency>
<groupId>software.amazon.awssdk</groupId> <groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId> <artifactId>dynamodb</artifactId>
@ -271,10 +275,6 @@
<groupId>com.amazonaws</groupId> <groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-s3</artifactId> <artifactId>aws-java-sdk-s3</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sqs</artifactId>
</dependency>
<dependency> <dependency>
<groupId>com.amazonaws</groupId> <groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-appconfig</artifactId> <artifactId>aws-java-sdk-appconfig</artifactId>

View File

@ -6,16 +6,6 @@ package org.whispersystems.textsecuregcm.sqs;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
@ -32,6 +22,15 @@ import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
public class DirectoryQueue { public class DirectoryQueue {
@ -43,18 +42,21 @@ public class DirectoryQueue {
private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch")); private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
private final List<String> queueUrls; private final List<String> queueUrls;
private final AmazonSQS sqs; private final SqsClient sqs;
public DirectoryQueue(SqsConfiguration sqsConfig) { public DirectoryQueue(SqsConfiguration sqsConfig) {
final AWSCredentials credentials = new BasicAWSCredentials(sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()); StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create(
final AWSStaticCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials); sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()));
this.queueUrls = sqsConfig.getQueueUrls(); this.queueUrls = sqsConfig.getQueueUrls();
this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build(); this.sqs = SqsClient.builder()
.region(Region.of(sqsConfig.getRegion()))
.credentialsProvider(credentialsProvider)
.build();
} }
@VisibleForTesting @VisibleForTesting
DirectoryQueue(final List<String> queueUrls, final AmazonSQS sqs) { DirectoryQueue(final List<String> queueUrls, final SqsClient sqs) {
this.queueUrls = queueUrls; this.queueUrls = queueUrls;
this.sqs = sqs; this.sqs = sqs;
} }
@ -82,28 +84,30 @@ public class DirectoryQueue {
final Account account = pair.first(); final Account account = pair.first();
final String action = pair.second(); final String action = pair.second();
return new SendMessageBatchRequestEntry() return SendMessageBatchRequestEntry.builder()
.withMessageBody("-") .messageBody("-")
.withId(UUID.randomUUID().toString()) .id(UUID.randomUUID().toString())
.withMessageDeduplicationId(UUID.randomUUID().toString()) .messageDeduplicationId(UUID.randomUUID().toString())
.withMessageGroupId(account.getNumber()) .messageGroupId(account.getNumber())
.withMessageAttributes(Map.of( .messageAttributes(Map.of(
"id", new MessageAttributeValue().withDataType("String").withStringValue(account.getNumber()), "id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
"uuid", new MessageAttributeValue().withDataType("String").withStringValue(account.getUuid().toString()), "uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
"action", new MessageAttributeValue().withDataType("String").withStringValue(action) "action", MessageAttributeValue.builder().dataType("String").stringValue(action).build()
)); ))
.build();
}).collect(Collectors.toList()); }).collect(Collectors.toList());
final SendMessageBatchRequest sendMessageBatchRequest = new SendMessageBatchRequest() final SendMessageBatchRequest sendMessageBatchRequest = SendMessageBatchRequest.builder()
.withQueueUrl(queueUrl) .queueUrl(queueUrl)
.withEntries(entries); .entries(entries)
.build();
try (final Timer.Context ignored = sendMessageBatchTimer.time()) { try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
sqs.sendMessageBatch(sendMessageBatchRequest); sqs.sendMessageBatch(sendMessageBatchRequest);
} catch (AmazonServiceException ex) { } catch (SdkServiceException ex) {
serviceErrorMeter.mark(); serviceErrorMeter.mark();
logger.warn("sqs service error: ", ex); logger.warn("sqs service error: ", ex);
} catch (AmazonClientException ex) { } catch (SdkClientException ex) {
clientErrorMeter.mark(); clientErrorMeter.mark();
logger.warn("sqs client error: ", ex); logger.warn("sqs client error: ", ex);
} catch (Throwable t) { } catch (Throwable t) {
@ -112,5 +116,4 @@ public class DirectoryQueue {
} }
} }
} }
} }

View File

@ -5,17 +5,15 @@
package org.whispersystems.textsecuregcm.sqs; package org.whispersystems.textsecuregcm.sqs;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import junitparams.JUnitParamsRunner; import junitparams.JUnitParamsRunner;
import junitparams.Parameters; import junitparams.Parameters;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -33,7 +31,7 @@ public class DirectoryQueueTest {
@Test @Test
@Parameters(method = "argumentsForTestRefreshRegisteredUser") @Parameters(method = "argumentsForTestRefreshRegisteredUser")
public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) { public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) {
final AmazonSQS sqs = mock(AmazonSQS.class); final SqsClient sqs = mock(SqsClient.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
final Account account = mock(Account.class); final Account account = mock(Account.class);
@ -47,15 +45,15 @@ public class DirectoryQueueTest {
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessageBatch(requestCaptor.capture()); verify(sqs).sendMessageBatch(requestCaptor.capture());
assertEquals(1, requestCaptor.getValue().getEntries().size()); assertEquals(1, requestCaptor.getValue().entries().size());
final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes(); final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().entries().get(0).messageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(), messageAttributes.get("action"));
} }
@Test @Test
public void testRefreshBatch() { public void testRefreshBatch() {
final AmazonSQS sqs = mock(AmazonSQS.class); final SqsClient sqs = mock(SqsClient.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
final Account discoverableAccount = mock(Account.class); final Account discoverableAccount = mock(Account.class);
@ -75,22 +73,22 @@ public class DirectoryQueueTest {
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessageBatch(requestCaptor.capture()); verify(sqs).sendMessageBatch(requestCaptor.capture());
assertEquals(2, requestCaptor.getValue().getEntries().size()); assertEquals(2, requestCaptor.getValue().entries().size());
final Map<String, MessageAttributeValue> discoverableAccountAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes(); final Map<String, MessageAttributeValue> discoverableAccountAttributes = requestCaptor.getValue().entries().get(0).messageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getNumber()), discoverableAccountAttributes.get("id")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getNumber()).build(), discoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getUuid().toString()), discoverableAccountAttributes.get("uuid")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getUuid().toString()).build(), discoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), discoverableAccountAttributes.get("action")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), discoverableAccountAttributes.get("action"));
final Map<String, MessageAttributeValue> undiscoverableAccountAttributes = requestCaptor.getValue().getEntries().get(1).getMessageAttributes(); final Map<String, MessageAttributeValue> undiscoverableAccountAttributes = requestCaptor.getValue().entries().get(1).messageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getNumber()), undiscoverableAccountAttributes.get("id")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getNumber()).build(), undiscoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getUuid().toString()), undiscoverableAccountAttributes.get("uuid")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getUuid().toString()).build(), undiscoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("delete"), undiscoverableAccountAttributes.get("action")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("delete").build(), undiscoverableAccountAttributes.get("action"));
} }
@Test @Test
public void testSendMessageMultipleQueues() { public void testSendMessageMultipleQueues() {
final AmazonSQS sqs = mock(AmazonSQS.class); final SqsClient sqs = mock(SqsClient.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs); final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs);
final Account account = mock(Account.class); final Account account = mock(Account.class);
@ -105,10 +103,10 @@ public class DirectoryQueueTest {
verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture()); verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture());
for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) { for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) {
assertEquals(1, requestCaptor.getValue().getEntries().size()); assertEquals(1, requestCaptor.getValue().entries().size());
final Map<String, MessageAttributeValue> messageAttributes = sendMessageBatchRequest.getEntries().get(0).getMessageAttributes(); final Map<String, MessageAttributeValue> messageAttributes = sendMessageBatchRequest.entries().get(0).messageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action")); assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), messageAttributes.get("action"));
} }
} }