Publish directory updates to multiple SQS queues.

This commit is contained in:
Jon Chambers 2021-01-08 16:23:43 -05:00 committed by Jon Chambers
parent 6af7bfb536
commit 9ee6419bc0
3 changed files with 58 additions and 36 deletions

View File

@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty; import org.hibernate.validator.constraints.NotEmpty;
import java.util.List;
public class SqsConfiguration { public class SqsConfiguration {
@NotEmpty @NotEmpty
@JsonProperty @JsonProperty
@ -18,7 +20,7 @@ public class SqsConfiguration {
@NotEmpty @NotEmpty
@JsonProperty @JsonProperty
private String queueUrl; private List<String> queueUrls;
@NotEmpty @NotEmpty
@JsonProperty @JsonProperty
@ -32,13 +34,11 @@ public class SqsConfiguration {
return accessSecret; return accessSecret;
} }
public String getQueueUrl() { public List<String> getQueueUrls() {
return queueUrl; return queueUrls;
} }
public String getRegion() { public String getRegion() {
return region; return region;
} }
} }

View File

@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
@ -39,21 +40,21 @@ public class DirectoryQueue {
private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage")); private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage"));
private final String queueUrl; private final List<String> queueUrls;
private final AmazonSQS sqs; private final AmazonSQS sqs;
public DirectoryQueue(SqsConfiguration sqsConfig) { public DirectoryQueue(SqsConfiguration sqsConfig) {
final AWSCredentials credentials = new BasicAWSCredentials(sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()); final AWSCredentials credentials = new BasicAWSCredentials(sqsConfig.getAccessKey(), sqsConfig.getAccessSecret());
final AWSStaticCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials); final AWSStaticCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials);
this.queueUrl = sqsConfig.getQueueUrl(); this.queueUrls = sqsConfig.getQueueUrls();
this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build(); this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build();
} }
@VisibleForTesting @VisibleForTesting
DirectoryQueue(final String queueUrl, final AmazonSQS sqs) { DirectoryQueue(final List<String> queueUrls, final AmazonSQS sqs) {
this.queueUrl = queueUrl; this.queueUrls = queueUrls;
this.sqs = sqs; this.sqs = sqs;
} }
public void refreshRegisteredUser(final Account account) { public void refreshRegisteredUser(final Account account) {
@ -69,22 +70,25 @@ public class DirectoryQueue {
messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number)); messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number));
messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString())); messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString()));
messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action)); messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action));
SendMessageRequest sendMessageRequest = new SendMessageRequest()
.withQueueUrl(queueUrl) for (final String queueUrl : queueUrls) {
.withMessageBody("-") final SendMessageRequest sendMessageRequest = new SendMessageRequest()
.withMessageDeduplicationId(UUID.randomUUID().toString()) .withQueueUrl(queueUrl)
.withMessageGroupId(number) .withMessageBody("-")
.withMessageAttributes(messageAttributes); .withMessageDeduplicationId(UUID.randomUUID().toString())
try (final Timer.Context ignored = sendMessageTimer.time()) { .withMessageGroupId(number)
sqs.sendMessage(sendMessageRequest); .withMessageAttributes(messageAttributes);
} catch (AmazonServiceException ex) { try (final Timer.Context ignored = sendMessageTimer.time()) {
serviceErrorMeter.mark(); sqs.sendMessage(sendMessageRequest);
logger.warn("sqs service error: ", ex); } catch (AmazonServiceException ex) {
} catch (AmazonClientException ex) { serviceErrorMeter.mark();
clientErrorMeter.mark(); logger.warn("sqs service error: ", ex);
logger.warn("sqs client error: ", ex); } catch (AmazonClientException ex) {
} catch (Throwable t) { clientErrorMeter.mark();
logger.warn("sqs unexpected error: ", t); logger.warn("sqs client error: ", ex);
} catch (Throwable t) {
logger.warn("sqs unexpected error: ", t);
}
} }
} }

View File

@ -16,29 +16,25 @@ import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@RunWith(JUnitParamsRunner.class) @RunWith(JUnitParamsRunner.class)
public class DirectoryQueueTest { public class DirectoryQueueTest {
private AmazonSQS sqs;
private DirectoryQueue directoryQueue;
@Before
public void setUp() {
sqs = mock(AmazonSQS.class);
directoryQueue = new DirectoryQueue("sqs://test", sqs);
}
@Test @Test
@Parameters(method = "argumentsForTestRefreshRegisteredUser") @Parameters(method = "argumentsForTestRefreshRegisteredUser")
public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) { public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543"); when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -54,6 +50,28 @@ public class DirectoryQueueTest {
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action")); assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action"));
} }
@Test
public void testSendMessageMultipleQueues() {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs);
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(true);
when(account.isDiscoverableByPhoneNumber()).thenReturn(true);
directoryQueue.refreshRegisteredUser(account);
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
verify(sqs, times(2)).sendMessage(requestCaptor.capture());
for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) {
final Map<String, MessageAttributeValue> messageAttributes = sendMessageRequest.getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action"));
}
}
@SuppressWarnings("unused") @SuppressWarnings("unused")
private Object argumentsForTestRefreshRegisteredUser() { private Object argumentsForTestRefreshRegisteredUser() {
return new Object[] { return new Object[] {