Send directory updates in batches.

This commit is contained in:
Jon Chambers 2021-05-03 12:33:31 -04:00 committed by Jon Chambers
parent 30c9968928
commit 8fdbcbef44
5 changed files with 121 additions and 44 deletions

View File

@ -145,6 +145,12 @@
<version>3.1.0</version> <version>3.1.0</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>30.1.1-jre</version>
</dependency>
<dependency> <dependency>
<groupId>com.googlecode.libphonenumber</groupId> <groupId>com.googlecode.libphonenumber</groupId>
<artifactId>libphonenumber</artifactId> <artifactId>libphonenumber</artifactId>

View File

@ -4,6 +4,8 @@
*/ */
package org.whispersystems.textsecuregcm.sqs; package org.whispersystems.textsecuregcm.sqs;
import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.AmazonClientException; import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException; import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentials;
@ -12,33 +14,33 @@ import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder; import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.MessageAttributeValue; import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import com.google.common.collect.Iterables;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration; import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class DirectoryQueue { public class DirectoryQueue {
private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class); private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError")); private final Meter serviceErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "serviceError"));
private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError")); private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
private final Timer sendMessageTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessage")); private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
private final List<String> queueUrls; private final List<String> queueUrls;
private final AmazonSQS sqs; private final AmazonSQS sqs;
@ -58,36 +60,54 @@ public class DirectoryQueue {
} }
public void refreshRegisteredUser(final Account account) { public void refreshRegisteredUser(final Account account) {
sendMessage(account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete", account.getUuid(), account.getNumber()); refreshRegisteredUsers(List.of(account));
}
public void refreshRegisteredUsers(final List<Account> accounts) {
final List<Pair<Account, String>> accountsAndActions = accounts.stream()
.map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete"))
.collect(Collectors.toList());
sendUpdateMessages(accountsAndActions);
} }
public void deleteAccount(final Account account) { public void deleteAccount(final Account account) {
sendMessage("delete", account.getUuid(), account.getNumber()); sendUpdateMessages(List.of(new Pair<>(account, "delete")));
} }
private void sendMessage(String action, UUID uuid, String number) { private void sendUpdateMessages(final List<Pair<Account, String>> accountsAndActions) {
final Map<String, MessageAttributeValue> messageAttributes = new HashMap<>();
messageAttributes.put("id", new MessageAttributeValue().withDataType("String").withStringValue(number));
messageAttributes.put("uuid", new MessageAttributeValue().withDataType("String").withStringValue(uuid.toString()));
messageAttributes.put("action", new MessageAttributeValue().withDataType("String").withStringValue(action));
for (final String queueUrl : queueUrls) { for (final String queueUrl : queueUrls) {
final SendMessageRequest sendMessageRequest = new SendMessageRequest() for (final List<Pair<Account, String>> partition : Iterables.partition(accountsAndActions, 10)) {
.withQueueUrl(queueUrl) final List<SendMessageBatchRequestEntry> entries = partition.stream().map(pair -> {
final Account account = pair.first();
final String action = pair.second();
return new SendMessageBatchRequestEntry()
.withMessageBody("-") .withMessageBody("-")
.withMessageDeduplicationId(UUID.randomUUID().toString()) .withMessageDeduplicationId(UUID.randomUUID().toString())
.withMessageGroupId(number) .withMessageGroupId(account.getNumber())
.withMessageAttributes(messageAttributes); .withMessageAttributes(Map.of(
try (final Timer.Context ignored = sendMessageTimer.time()) { "id", new MessageAttributeValue().withDataType("String").withStringValue(account.getNumber()),
sqs.sendMessage(sendMessageRequest); "uuid", new MessageAttributeValue().withDataType("String").withStringValue(account.getUuid().toString()),
} catch (AmazonServiceException ex) { "action", new MessageAttributeValue().withDataType("String").withStringValue(action)
serviceErrorMeter.mark(); ));
logger.warn("sqs service error: ", ex); }).collect(Collectors.toList());
} catch (AmazonClientException ex) {
clientErrorMeter.mark(); final SendMessageBatchRequest sendMessageBatchRequest = new SendMessageBatchRequest()
logger.warn("sqs client error: ", ex); .withQueueUrl(queueUrl)
} catch (Throwable t) { .withEntries(entries);
logger.warn("sqs unexpected error: ", t);
try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
sqs.sendMessageBatch(sendMessageBatchRequest);
} catch (AmazonServiceException ex) {
serviceErrorMeter.mark();
logger.warn("sqs service error: ", ex);
} catch (AmazonClientException ex) {
clientErrorMeter.mark();
logger.warn("sqs client error: ", ex);
} catch (Throwable t) {
logger.warn("sqs unexpected error: ", t);
}
} }
} }
} }

View File

@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -41,6 +42,8 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@Override @Override
protected void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) { protected void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) {
final List<Account> directoryUpdateAccounts = new ArrayList<>();
for (Account account : chunkAccounts) { for (Account account : chunkAccounts) {
boolean update = false; boolean update = false;
@ -74,8 +77,12 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
if (update) { if (update) {
accountsManager.update(account); accountsManager.update(account);
directoryQueue.refreshRegisteredUser(account); directoryUpdateAccounts.add(account);
} }
} }
if (!directoryUpdateAccounts.isEmpty()) {
directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts);
}
} }
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.sqs;
import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.model.MessageAttributeValue; import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageRequest; import com.amazonaws.services.sqs.model.SendMessageRequest;
import junitparams.JUnitParamsRunner; import junitparams.JUnitParamsRunner;
import junitparams.Parameters; import junitparams.Parameters;
@ -43,13 +44,50 @@ public class DirectoryQueueTest {
directoryQueue.refreshRegisteredUser(account); directoryQueue.refreshRegisteredUser(account);
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessage(requestCaptor.capture()); verify(sqs).sendMessageBatch(requestCaptor.capture());
final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().getMessageAttributes(); assertEquals(1, requestCaptor.getValue().getEntries().size());
final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action")); assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action"));
} }
@Test
public void testRefreshBatch() {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
final Account discoverableAccount = mock(Account.class);
when(discoverableAccount.getNumber()).thenReturn("+18005556543");
when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
when(discoverableAccount.isEnabled()).thenReturn(true);
when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true);
final Account undiscoverableAccount = mock(Account.class);
when(undiscoverableAccount.getNumber()).thenReturn("+18005550987");
when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
when(undiscoverableAccount.isEnabled()).thenReturn(true);
when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false);
directoryQueue.refreshRegisteredUsers(List.of(discoverableAccount, undiscoverableAccount));
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessageBatch(requestCaptor.capture());
assertEquals(2, requestCaptor.getValue().getEntries().size());
final Map<String, MessageAttributeValue> discoverableAccountAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getNumber()), discoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getUuid().toString()), discoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), discoverableAccountAttributes.get("action"));
final Map<String, MessageAttributeValue> undiscoverableAccountAttributes = requestCaptor.getValue().getEntries().get(1).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getNumber()), undiscoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getUuid().toString()), undiscoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("delete"), undiscoverableAccountAttributes.get("action"));
}
@Test @Test
public void testSendMessageMultipleQueues() { public void testSendMessageMultipleQueues() {
final AmazonSQS sqs = mock(AmazonSQS.class); final AmazonSQS sqs = mock(AmazonSQS.class);
@ -63,11 +101,13 @@ public class DirectoryQueueTest {
directoryQueue.refreshRegisteredUser(account); directoryQueue.refreshRegisteredUser(account);
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs, times(2)).sendMessage(requestCaptor.capture()); verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture());
for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) { for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) {
final Map<String, MessageAttributeValue> messageAttributes = sendMessageRequest.getMessageAttributes(); assertEquals(1, requestCaptor.getValue().getEntries().size());
final Map<String, MessageAttributeValue> messageAttributes = sendMessageBatchRequest.getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action")); assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action"));
} }
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
@ -22,6 +23,7 @@ import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
public class PushFeedbackProcessorTest { public class PushFeedbackProcessorTest {
@ -131,8 +133,10 @@ public class PushFeedbackProcessorTest {
verify(accountsManager).update(eq(stillActiveAccount)); verify(accountsManager).update(eq(stillActiveAccount));
verify(directoryQueue).refreshRegisteredUser(undiscoverableAccount); final ArgumentCaptor<List<Account>> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class);
verify(directoryQueue).refreshRegisteredUser(uninstalledAccount); verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture());
assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount)));
} }