diff --git a/service/pom.xml b/service/pom.xml
index aa1a93ed8..ecac4bbf5 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -145,6 +145,12 @@
3.1.0
+
+ com.google.guava
+ guava
+ 30.1.1-jre
+
+
com.googlecode.libphonenumber
libphonenumber
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java
index e9766e90a..fff3d54ab 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueue.java
@@ -4,6 +4,8 @@
*/
package org.whispersystems.textsecuregcm.sqs;
+import static com.codahale.metrics.MetricRegistry.name;
+
import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentials;
@@ -12,33 +14,33 @@ import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
-import com.amazonaws.services.sqs.model.SendMessageRequest;
+import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
+import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import com.google.common.collect.Iterables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.UUID;
-
-import static com.codahale.metrics.MetricRegistry.name;
+import org.whispersystems.textsecuregcm.util.Pair;
public class DirectoryQueue {
private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class);
- private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
- private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError"));
- private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
- private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage"));
+ private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError"));
+ private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
+ private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
private final List queueUrls;
private final AmazonSQS sqs;
@@ -58,36 +60,54 @@ public class DirectoryQueue {
}
public void refreshRegisteredUser(final Account account) {
- sendMessage(account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete", account.getUuid(), account.getNumber());
+ refreshRegisteredUsers(List.of(account));
+ }
+
+ public void refreshRegisteredUsers(final List accounts) {
+ final List> accountsAndActions = accounts.stream()
+ .map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete"))
+ .collect(Collectors.toList());
+
+ sendUpdateMessages(accountsAndActions);
}
public void deleteAccount(final Account account) {
- sendMessage("delete", account.getUuid(), account.getNumber());
+ sendUpdateMessages(List.of(new Pair<>(account, "delete")));
}
- private void sendMessage(String action, UUID uuid, String number) {
- final Map messageAttributes = new HashMap<>();
- messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number));
- messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString()));
- messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action));
-
+ private void sendUpdateMessages(final List> accountsAndActions) {
for (final String queueUrl : queueUrls) {
- final SendMessageRequest sendMessageRequest = new SendMessageRequest()
- .withQueueUrl(queueUrl)
+ for (final List> partition : Iterables.partition(accountsAndActions, 10)) {
+ final List entries = partition.stream().map(pair -> {
+ final Account account = pair.first();
+ final String action = pair.second();
+
+ return new SendMessageBatchRequestEntry()
.withMessageBody("-")
.withMessageDeduplicationId(UUID.randomUUID().toString())
- .withMessageGroupId(number)
- .withMessageAttributes(messageAttributes);
- try (final Timer.Context ignored = sendMessageTimer.time()) {
- sqs.sendMessage(sendMessageRequest);
- } catch (AmazonServiceException ex) {
- serviceErrorMeter.mark();
- logger.warn("sqs service error: ", ex);
- } catch (AmazonClientException ex) {
- clientErrorMeter.mark();
- logger.warn("sqs client error: ", ex);
- } catch (Throwable t) {
- logger.warn("sqs unexpected error: ", t);
+ .withMessageGroupId(account.getNumber())
+ .withMessageAttributes(Map.of(
+ "id", new MessageAttributeValue().withDataType("String").withStringValue(account.getNumber()),
+ "uuid", new MessageAttributeValue().withDataType("String").withStringValue(account.getUuid().toString()),
+ "action", new MessageAttributeValue().withDataType("String").withStringValue(action)
+ ));
+ }).collect(Collectors.toList());
+
+ final SendMessageBatchRequest sendMessageBatchRequest = new SendMessageBatchRequest()
+ .withQueueUrl(queueUrl)
+ .withEntries(entries);
+
+ try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
+ sqs.sendMessageBatch(sendMessageBatchRequest);
+ } catch (AmazonServiceException ex) {
+ serviceErrorMeter.mark();
+ logger.warn("sqs service error: ", ex);
+ } catch (AmazonClientException ex) {
+ clientErrorMeter.mark();
+ logger.warn("sqs client error: ", ex);
+ } catch (Throwable t) {
+ logger.warn("sqs unexpected error: ", t);
+ }
}
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java
index e797bb088..60b311e83 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java
@@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
@@ -41,6 +42,8 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@Override
protected void onCrawlChunk(Optional fromUuid, List chunkAccounts) {
+ final List directoryUpdateAccounts = new ArrayList<>();
+
for (Account account : chunkAccounts) {
boolean update = false;
@@ -74,8 +77,12 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
if (update) {
accountsManager.update(account);
- directoryQueue.refreshRegisteredUser(account);
+ directoryUpdateAccounts.add(account);
}
}
+
+ if (!directoryUpdateAccounts.isEmpty()) {
+ directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts);
+ }
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java
index cf1575313..3bf270a7c 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/sqs/DirectoryQueueTest.java
@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.sqs;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
+import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
@@ -43,13 +44,50 @@ public class DirectoryQueueTest {
directoryQueue.refreshRegisteredUser(account);
- final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
- verify(sqs).sendMessage(requestCaptor.capture());
+ final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(sqs).sendMessageBatch(requestCaptor.capture());
- final Map messageAttributes = requestCaptor.getValue().getMessageAttributes();
+ assertEquals(1, requestCaptor.getValue().getEntries().size());
+
+ final Map messageAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action"));
}
+ @Test
+ public void testRefreshBatch() {
+ final AmazonSQS sqs = mock(AmazonSQS.class);
+ final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
+
+ final Account discoverableAccount = mock(Account.class);
+ when(discoverableAccount.getNumber()).thenReturn("+18005556543");
+ when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
+ when(discoverableAccount.isEnabled()).thenReturn(true);
+ when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true);
+
+ final Account undiscoverableAccount = mock(Account.class);
+ when(undiscoverableAccount.getNumber()).thenReturn("+18005550987");
+ when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
+ when(undiscoverableAccount.isEnabled()).thenReturn(true);
+ when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false);
+
+ directoryQueue.refreshRegisteredUsers(List.of(discoverableAccount, undiscoverableAccount));
+
+ final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(sqs).sendMessageBatch(requestCaptor.capture());
+
+ assertEquals(2, requestCaptor.getValue().getEntries().size());
+
+ final Map discoverableAccountAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getNumber()), discoverableAccountAttributes.get("id"));
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getUuid().toString()), discoverableAccountAttributes.get("uuid"));
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), discoverableAccountAttributes.get("action"));
+
+ final Map undiscoverableAccountAttributes = requestCaptor.getValue().getEntries().get(1).getMessageAttributes();
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getNumber()), undiscoverableAccountAttributes.get("id"));
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getUuid().toString()), undiscoverableAccountAttributes.get("uuid"));
+ assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("delete"), undiscoverableAccountAttributes.get("action"));
+ }
+
@Test
public void testSendMessageMultipleQueues() {
final AmazonSQS sqs = mock(AmazonSQS.class);
@@ -63,11 +101,13 @@ public class DirectoryQueueTest {
directoryQueue.refreshRegisteredUser(account);
- final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
- verify(sqs, times(2)).sendMessage(requestCaptor.capture());
+ final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture());
- for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) {
- final Map messageAttributes = sendMessageRequest.getMessageAttributes();
+ for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) {
+ assertEquals(1, requestCaptor.getValue().getEntries().size());
+
+ final Map messageAttributes = sendMessageBatchRequest.getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action"));
}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java
index 03e5f5438..1a448d66e 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java
@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before;
import org.junit.Test;
+import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
@@ -22,6 +23,7 @@ import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
public class PushFeedbackProcessorTest {
@@ -131,8 +133,10 @@ public class PushFeedbackProcessorTest {
verify(accountsManager).update(eq(stillActiveAccount));
- verify(directoryQueue).refreshRegisteredUser(undiscoverableAccount);
- verify(directoryQueue).refreshRegisteredUser(uninstalledAccount);
+ final ArgumentCaptor> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class);
+ verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture());
+
+ assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount)));
}