Switch to an async SQS client.
This commit is contained in:
parent
a6066bfc2f
commit
34dbff6786
|
@ -14,23 +14,19 @@ import com.google.common.annotations.VisibleForTesting;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
import com.google.common.collect.Iterables;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
||||
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||
import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.core.exception.SdkServiceException;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.sqs.SqsClient;
|
||||
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
|
||||
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
|
||||
|
||||
public class DirectoryQueue {
|
||||
|
||||
|
@ -41,22 +37,38 @@ public class DirectoryQueue {
|
|||
private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
|
||||
private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
|
||||
|
||||
private final List<String> queueUrls;
|
||||
private final SqsClient sqs;
|
||||
private final List<String> queueUrls;
|
||||
private final SqsAsyncClient sqs;
|
||||
|
||||
private enum UpdateAction {
|
||||
ADD("add"),
|
||||
DELETE("delete");
|
||||
|
||||
private final String action;
|
||||
|
||||
UpdateAction(final String action) {
|
||||
this.action = action;
|
||||
}
|
||||
|
||||
public MessageAttributeValue toMessageAttributeValue() {
|
||||
return MessageAttributeValue.builder().dataType("String").stringValue(action).build();
|
||||
}
|
||||
}
|
||||
|
||||
public DirectoryQueue(SqsConfiguration sqsConfig) {
|
||||
StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create(
|
||||
sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()));
|
||||
|
||||
this.queueUrls = sqsConfig.getQueueUrls();
|
||||
this.sqs = SqsClient.builder()
|
||||
|
||||
this.sqs = SqsAsyncClient.builder()
|
||||
.region(Region.of(sqsConfig.getRegion()))
|
||||
.credentialsProvider(credentialsProvider)
|
||||
.build();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
DirectoryQueue(final List<String> queueUrls, final SqsClient sqs) {
|
||||
DirectoryQueue(final List<String> queueUrls, final SqsAsyncClient sqs) {
|
||||
this.queueUrls = queueUrls;
|
||||
this.sqs = sqs;
|
||||
}
|
||||
|
@ -66,58 +78,44 @@ public class DirectoryQueue {
|
|||
}
|
||||
|
||||
public void refreshAccount(final Account account) {
|
||||
refreshAccounts(List.of(account));
|
||||
}
|
||||
|
||||
public void refreshAccounts(final List<Account> accounts) {
|
||||
final List<Pair<Account, String>> accountsAndActions = accounts.stream()
|
||||
.map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete"))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
sendUpdateMessages(accountsAndActions);
|
||||
sendUpdateMessage(account, isDiscoverable(account) ? UpdateAction.ADD : UpdateAction.DELETE);
|
||||
}
|
||||
|
||||
public void deleteAccount(final Account account) {
|
||||
sendUpdateMessages(List.of(new Pair<>(account, "delete")));
|
||||
sendUpdateMessage(account, UpdateAction.DELETE);
|
||||
}
|
||||
|
||||
private void sendUpdateMessages(final List<Pair<Account, String>> accountsAndActions) {
|
||||
private void sendUpdateMessage(final Account account, final UpdateAction action) {
|
||||
for (final String queueUrl : queueUrls) {
|
||||
for (final List<Pair<Account, String>> partition : Iterables.partition(accountsAndActions, 10)) {
|
||||
final List<SendMessageBatchRequestEntry> entries = partition.stream().map(pair -> {
|
||||
final Account account = pair.first();
|
||||
final String action = pair.second();
|
||||
final Timer.Context timerContext = sendMessageBatchTimer.time();
|
||||
|
||||
return SendMessageBatchRequestEntry.builder()
|
||||
.messageBody("-")
|
||||
.id(UUID.randomUUID().toString())
|
||||
.messageDeduplicationId(UUID.randomUUID().toString())
|
||||
.messageGroupId(account.getNumber())
|
||||
.messageAttributes(Map.of(
|
||||
"id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
|
||||
"uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
|
||||
"action", MessageAttributeValue.builder().dataType("String").stringValue(action).build()
|
||||
))
|
||||
.build();
|
||||
}).collect(Collectors.toList());
|
||||
final SendMessageRequest request = SendMessageRequest.builder()
|
||||
.queueUrl(queueUrl)
|
||||
.messageBody("-")
|
||||
.messageDeduplicationId(UUID.randomUUID().toString())
|
||||
.messageGroupId(account.getNumber())
|
||||
.messageAttributes(Map.of(
|
||||
"id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
|
||||
"uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
|
||||
"action", action.toMessageAttributeValue()
|
||||
))
|
||||
.build();
|
||||
|
||||
final SendMessageBatchRequest sendMessageBatchRequest = SendMessageBatchRequest.builder()
|
||||
.queueUrl(queueUrl)
|
||||
.entries(entries)
|
||||
.build();
|
||||
|
||||
try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
|
||||
sqs.sendMessageBatch(sendMessageBatchRequest);
|
||||
} catch (SdkServiceException ex) {
|
||||
serviceErrorMeter.mark();
|
||||
logger.warn("sqs service error: ", ex);
|
||||
} catch (SdkClientException ex) {
|
||||
clientErrorMeter.mark();
|
||||
logger.warn("sqs client error: ", ex);
|
||||
} catch (Throwable t) {
|
||||
logger.warn("sqs unexpected error: ", t);
|
||||
sqs.sendMessage(request).whenComplete((response, cause) -> {
|
||||
try {
|
||||
if (cause instanceof SdkServiceException) {
|
||||
serviceErrorMeter.mark();
|
||||
logger.warn("sqs service error", cause);
|
||||
} else if (cause instanceof SdkClientException) {
|
||||
clientErrorMeter.mark();
|
||||
logger.warn("sqs client error", cause);
|
||||
} else if (cause != null) {
|
||||
logger.warn("sqs unexpected error", cause);
|
||||
}
|
||||
} finally {
|
||||
timerContext.close();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,34 +5,45 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.sqs;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import software.amazon.awssdk.services.sqs.SqsClient;
|
||||
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
|
||||
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
|
||||
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
|
||||
|
||||
public class DirectoryQueueTest {
|
||||
|
||||
private SqsAsyncClient sqsAsyncClient;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
sqsAsyncClient = mock(SqsAsyncClient.class);
|
||||
|
||||
when(sqsAsyncClient.sendMessage(any(SendMessageRequest.class)))
|
||||
.thenReturn(CompletableFuture.completedFuture(SendMessageResponse.builder().build()));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("argumentsForTestRefreshRegisteredUser")
|
||||
void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) {
|
||||
final SqsClient sqs = mock(SqsClient.class);
|
||||
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
|
||||
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqsAsyncClient);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005556543");
|
||||
|
@ -42,13 +53,11 @@ public class DirectoryQueueTest {
|
|||
|
||||
directoryQueue.refreshAccount(account);
|
||||
|
||||
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
|
||||
verify(sqs).sendMessageBatch(requestCaptor.capture());
|
||||
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
|
||||
verify(sqsAsyncClient).sendMessage(requestCaptor.capture());
|
||||
|
||||
assertEquals(1, requestCaptor.getValue().entries().size());
|
||||
|
||||
final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().entries().get(0).messageAttributes();
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(), messageAttributes.get("action"));
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(),
|
||||
requestCaptor.getValue().messageAttributes().get("action"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unused")
|
||||
|
@ -60,45 +69,9 @@ public class DirectoryQueueTest {
|
|||
Arguments.of(false, false, "delete"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRefreshBatch() {
|
||||
final SqsClient sqs = mock(SqsClient.class);
|
||||
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
|
||||
|
||||
final Account discoverableAccount = mock(Account.class);
|
||||
when(discoverableAccount.getNumber()).thenReturn("+18005556543");
|
||||
when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(discoverableAccount.isEnabled()).thenReturn(true);
|
||||
when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true);
|
||||
|
||||
final Account undiscoverableAccount = mock(Account.class);
|
||||
when(undiscoverableAccount.getNumber()).thenReturn("+18005550987");
|
||||
when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(undiscoverableAccount.isEnabled()).thenReturn(true);
|
||||
when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false);
|
||||
|
||||
directoryQueue.refreshAccounts(List.of(discoverableAccount, undiscoverableAccount));
|
||||
|
||||
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
|
||||
verify(sqs).sendMessageBatch(requestCaptor.capture());
|
||||
|
||||
assertEquals(2, requestCaptor.getValue().entries().size());
|
||||
|
||||
final Map<String, MessageAttributeValue> discoverableAccountAttributes = requestCaptor.getValue().entries().get(0).messageAttributes();
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getNumber()).build(), discoverableAccountAttributes.get("id"));
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(discoverableAccount.getUuid().toString()).build(), discoverableAccountAttributes.get("uuid"));
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), discoverableAccountAttributes.get("action"));
|
||||
|
||||
final Map<String, MessageAttributeValue> undiscoverableAccountAttributes = requestCaptor.getValue().entries().get(1).messageAttributes();
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getNumber()).build(), undiscoverableAccountAttributes.get("id"));
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(undiscoverableAccount.getUuid().toString()).build(), undiscoverableAccountAttributes.get("uuid"));
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("delete").build(), undiscoverableAccountAttributes.get("action"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSendMessageMultipleQueues() {
|
||||
final SqsClient sqs = mock(SqsClient.class);
|
||||
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs);
|
||||
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqsAsyncClient);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005556543");
|
||||
|
@ -108,14 +81,12 @@ public class DirectoryQueueTest {
|
|||
|
||||
directoryQueue.refreshAccount(account);
|
||||
|
||||
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
|
||||
verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture());
|
||||
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
|
||||
verify(sqsAsyncClient, times(2)).sendMessage(requestCaptor.capture());
|
||||
|
||||
for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) {
|
||||
assertEquals(1, requestCaptor.getValue().entries().size());
|
||||
|
||||
final Map<String, MessageAttributeValue> messageAttributes = sendMessageBatchRequest.entries().get(0).messageAttributes();
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(), messageAttributes.get("action"));
|
||||
for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) {
|
||||
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(),
|
||||
sendMessageRequest.messageAttributes().get("action"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue