diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SqsConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SqsConfiguration.java index b46908111..47506a052 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SqsConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/SqsConfiguration.java @@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.configuration; import com.fasterxml.jackson.annotation.JsonProperty; import org.hibernate.validator.constraints.NotEmpty; +import java.util.List; + public class SqsConfiguration { @NotEmpty @JsonProperty @@ -18,7 +20,7 @@ public class SqsConfiguration { @NotEmpty @JsonProperty - private String queueUrl; + private List queueUrls; @NotEmpty @JsonProperty @@ -32,13 +34,11 @@ public class SqsConfiguration { return accessSecret; } - public String getQueueUrl() { - return queueUrl; + public List getQueueUrls() { + return queueUrls; } public String getRegion() { return region; } } - - diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java index 53a056982..e9766e90a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java @@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Constants; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.UUID; @@ -39,21 +40,21 @@ public class DirectoryQueue { private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage")); - private final String queueUrl; + private final List queueUrls; private final AmazonSQS sqs; public DirectoryQueue(SqsConfiguration sqsConfig) { final AWSCredentials credentials = new BasicAWSCredentials(sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()); final AWSStaticCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials); - this.queueUrl = sqsConfig.getQueueUrl(); - this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build(); + this.queueUrls = sqsConfig.getQueueUrls(); + this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build(); } @VisibleForTesting - DirectoryQueue(final String queueUrl, final AmazonSQS sqs) { - this.queueUrl = queueUrl; - this.sqs = sqs; + DirectoryQueue(final List queueUrls, final AmazonSQS sqs) { + this.queueUrls = queueUrls; + this.sqs = sqs; } public void refreshRegisteredUser(final Account account) { @@ -69,22 +70,25 @@ public class DirectoryQueue { messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number)); messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString())); messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action)); - SendMessageRequest sendMessageRequest = new SendMessageRequest() - .withQueueUrl(queueUrl) - .withMessageBody("-") - .withMessageDeduplicationId(UUID.randomUUID().toString()) - .withMessageGroupId(number) - .withMessageAttributes(messageAttributes); - try (final Timer.Context ignored = sendMessageTimer.time()) { - sqs.sendMessage(sendMessageRequest); - } catch (AmazonServiceException ex) { - serviceErrorMeter.mark(); - logger.warn("sqs service error: ", ex); - } catch (AmazonClientException ex) { - clientErrorMeter.mark(); - logger.warn("sqs client error: ", ex); - } catch (Throwable t) { - logger.warn("sqs unexpected error: ", t); + + for (final String queueUrl : queueUrls) { + final SendMessageRequest sendMessageRequest = new SendMessageRequest() + .withQueueUrl(queueUrl) + .withMessageBody("-") + .withMessageDeduplicationId(UUID.randomUUID().toString()) + .withMessageGroupId(number) + .withMessageAttributes(messageAttributes); + try (final Timer.Context ignored = sendMessageTimer.time()) { + sqs.sendMessage(sendMessageRequest); + } catch (AmazonServiceException ex) { + serviceErrorMeter.mark(); + logger.warn("sqs service error: ", ex); + } catch (AmazonClientException ex) { + clientErrorMeter.mark(); + logger.warn("sqs client error: ", ex); + } catch (Throwable t) { + logger.warn("sqs unexpected error: ", t); + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java index 510fe5191..cf1575313 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java @@ -16,29 +16,25 @@ import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.storage.Account; +import java.util.List; import java.util.Map; import java.util.UUID; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @RunWith(JUnitParamsRunner.class) public class DirectoryQueueTest { - private AmazonSQS sqs; - private DirectoryQueue directoryQueue; - - @Before - public void setUp() { - sqs = mock(AmazonSQS.class); - directoryQueue = new DirectoryQueue("sqs://test", sqs); - } - @Test @Parameters(method = "argumentsForTestRefreshRegisteredUser") public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) { + final AmazonSQS sqs = mock(AmazonSQS.class); + final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); + final Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005556543"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -54,6 +50,28 @@ public class DirectoryQueueTest { assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action")); } + @Test + public void testSendMessageMultipleQueues() { + final AmazonSQS sqs = mock(AmazonSQS.class); + final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs); + + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn("+18005556543"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.isEnabled()).thenReturn(true); + when(account.isDiscoverableByPhoneNumber()).thenReturn(true); + + directoryQueue.refreshRegisteredUser(account); + + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(sqs, times(2)).sendMessage(requestCaptor.capture()); + + for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) { + final Map messageAttributes = sendMessageRequest.getMessageAttributes(); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action")); + } + } + @SuppressWarnings("unused") private Object argumentsForTestRefreshRegisteredUser() { return new Object[] {