diff --git a/service/pom.xml b/service/pom.xml index aa1a93ed8..ecac4bbf5 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -145,6 +145,12 @@ 3.1.0 + + com.google.guava + guava + 30.1.1-jre + + com.googlecode.libphonenumber libphonenumber diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java index e9766e90a..fff3d54ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java @@ -4,6 +4,8 @@ */ package org.whispersystems.textsecuregcm.sqs; +import static com.codahale.metrics.MetricRegistry.name; + import com.amazonaws.AmazonClientException; import com.amazonaws.AmazonServiceException; import com.amazonaws.auth.AWSCredentials; @@ -12,33 +14,33 @@ import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQSClientBuilder; import com.amazonaws.services.sqs.model.MessageAttributeValue; -import com.amazonaws.services.sqs.model.SendMessageRequest; +import com.amazonaws.services.sqs.model.SendMessageBatchRequest; +import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import com.google.common.collect.Iterables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.SqsConfiguration; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Constants; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; - -import static com.codahale.metrics.MetricRegistry.name; +import org.whispersystems.textsecuregcm.util.Pair; public class DirectoryQueue { private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class); - private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError")); - private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); - private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage")); + private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError")); + private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); + private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch")); private final List queueUrls; private final AmazonSQS sqs; @@ -58,36 +60,54 @@ public class DirectoryQueue { } public void refreshRegisteredUser(final Account account) { - sendMessage(account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete", account.getUuid(), account.getNumber()); + refreshRegisteredUsers(List.of(account)); + } + + public void refreshRegisteredUsers(final List accounts) { + final List> accountsAndActions = accounts.stream() + .map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete")) + .collect(Collectors.toList()); + + sendUpdateMessages(accountsAndActions); } public void deleteAccount(final Account account) { - sendMessage("delete", account.getUuid(), account.getNumber()); + sendUpdateMessages(List.of(new Pair<>(account, "delete"))); } - private void sendMessage(String action, UUID uuid, String number) { - final Map messageAttributes = new HashMap<>(); - messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number)); - messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString())); - messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action)); - + private void sendUpdateMessages(final List> accountsAndActions) { for (final String queueUrl : queueUrls) { - final SendMessageRequest sendMessageRequest = new SendMessageRequest() - .withQueueUrl(queueUrl) + for (final List> partition : Iterables.partition(accountsAndActions, 10)) { + final List entries = partition.stream().map(pair -> { + final Account account = pair.first(); + final String action = pair.second(); + + return new SendMessageBatchRequestEntry() .withMessageBody("-") .withMessageDeduplicationId(UUID.randomUUID().toString()) - .withMessageGroupId(number) - .withMessageAttributes(messageAttributes); - try (final Timer.Context ignored = sendMessageTimer.time()) { - sqs.sendMessage(sendMessageRequest); - } catch (AmazonServiceException ex) { - serviceErrorMeter.mark(); - logger.warn("sqs service error: ", ex); - } catch (AmazonClientException ex) { - clientErrorMeter.mark(); - logger.warn("sqs client error: ", ex); - } catch (Throwable t) { - logger.warn("sqs unexpected error: ", t); + .withMessageGroupId(account.getNumber()) + .withMessageAttributes(Map.of( + "id", new MessageAttributeValue().withDataType("String").withStringValue(account.getNumber()), + "uuid", new MessageAttributeValue().withDataType("String").withStringValue(account.getUuid().toString()), + "action", new MessageAttributeValue().withDataType("String").withStringValue(action) + )); + }).collect(Collectors.toList()); + + final SendMessageBatchRequest sendMessageBatchRequest = new SendMessageBatchRequest() + .withQueueUrl(queueUrl) + .withEntries(entries); + + try (final Timer.Context ignored = sendMessageBatchTimer.time()) { + sqs.sendMessageBatch(sendMessageBatchRequest); + } catch (AmazonServiceException ex) { + serviceErrorMeter.mark(); + logger.warn("sqs service error: ", ex); + } catch (AmazonClientException ex) { + clientErrorMeter.mark(); + logger.warn("sqs client error: ", ex); + } catch (Throwable t) { + logger.warn("sqs unexpected error: ", t); + } } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java index e797bb088..60b311e83 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java @@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -41,6 +42,8 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener { @Override protected void onCrawlChunk(Optional fromUuid, List chunkAccounts) { + final List directoryUpdateAccounts = new ArrayList<>(); + for (Account account : chunkAccounts) { boolean update = false; @@ -74,8 +77,12 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener { if (update) { accountsManager.update(account); - directoryQueue.refreshRegisteredUser(account); + directoryUpdateAccounts.add(account); } } + + if (!directoryUpdateAccounts.isEmpty()) { + directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts); + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java index cf1575313..3bf270a7c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.sqs; import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.model.MessageAttributeValue; +import com.amazonaws.services.sqs.model.SendMessageBatchRequest; import com.amazonaws.services.sqs.model.SendMessageRequest; import junitparams.JUnitParamsRunner; import junitparams.Parameters; @@ -43,13 +44,50 @@ public class DirectoryQueueTest { directoryQueue.refreshRegisteredUser(account); - final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); - verify(sqs).sendMessage(requestCaptor.capture()); + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(sqs).sendMessageBatch(requestCaptor.capture()); - final Map messageAttributes = requestCaptor.getValue().getMessageAttributes(); + assertEquals(1, requestCaptor.getValue().getEntries().size()); + + final Map messageAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes(); assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action")); } + @Test + public void testRefreshBatch() { + final AmazonSQS sqs = mock(AmazonSQS.class); + final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs); + + final Account discoverableAccount = mock(Account.class); + when(discoverableAccount.getNumber()).thenReturn("+18005556543"); + when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID()); + when(discoverableAccount.isEnabled()).thenReturn(true); + when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true); + + final Account undiscoverableAccount = mock(Account.class); + when(undiscoverableAccount.getNumber()).thenReturn("+18005550987"); + when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID()); + when(undiscoverableAccount.isEnabled()).thenReturn(true); + when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false); + + directoryQueue.refreshRegisteredUsers(List.of(discoverableAccount, undiscoverableAccount)); + + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(sqs).sendMessageBatch(requestCaptor.capture()); + + assertEquals(2, requestCaptor.getValue().getEntries().size()); + + final Map discoverableAccountAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes(); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getNumber()), discoverableAccountAttributes.get("id")); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getUuid().toString()), discoverableAccountAttributes.get("uuid")); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), discoverableAccountAttributes.get("action")); + + final Map undiscoverableAccountAttributes = requestCaptor.getValue().getEntries().get(1).getMessageAttributes(); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getNumber()), undiscoverableAccountAttributes.get("id")); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getUuid().toString()), undiscoverableAccountAttributes.get("uuid")); + assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("delete"), undiscoverableAccountAttributes.get("action")); + } + @Test public void testSendMessageMultipleQueues() { final AmazonSQS sqs = mock(AmazonSQS.class); @@ -63,11 +101,13 @@ public class DirectoryQueueTest { directoryQueue.refreshRegisteredUser(account); - final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); - verify(sqs, times(2)).sendMessage(requestCaptor.capture()); + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture()); - for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) { - final Map messageAttributes = sendMessageRequest.getMessageAttributes(); + for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) { + assertEquals(1, requestCaptor.getValue().getEntries().size()); + + final Map messageAttributes = sendMessageBatchRequest.getEntries().get(0).getMessageAttributes(); assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action")); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java index 03e5f5438..1a448d66e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.storage; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; @@ -22,6 +23,7 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; public class PushFeedbackProcessorTest { @@ -131,8 +133,10 @@ public class PushFeedbackProcessorTest { verify(accountsManager).update(eq(stillActiveAccount)); - verify(directoryQueue).refreshRegisteredUser(undiscoverableAccount); - verify(directoryQueue).refreshRegisteredUser(uninstalledAccount); + final ArgumentCaptor> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class); + verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture()); + + assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount))); }