From d13741fbd52498635004ebfb003d4396267a5da9 Mon Sep 17 00:00:00 2001 From: Ehren Kret Date: Thu, 12 Aug 2021 11:37:33 -0500 Subject: [PATCH] Change from using parallel streams to using an ExecutorService --- .../textsecuregcm/WhisperServerService.java | 3 +- .../controllers/MessageController.java | 65 ++++++++++++------- .../MessageControllerMetricsTest.java | 7 +- .../controllers/MessageControllerTest.java | 31 ++++----- 4 files changed, 63 insertions(+), 43 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index b47e2d5ca..efe4c8429 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -417,6 +417,7 @@ public class WhisperServerService extends Application uuids404 = Collections.synchronizedList(new ArrayList<>()); final Counter counter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags); - Arrays.stream(multiRecipientMessage.getRecipients()).parallel().forEach(recipient -> { - Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); + try { + multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.getRecipients()) + .map(recipient -> (Callable) () -> { + Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); - // we asserted this must exist in validateCompleteDeviceList - Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); - counter.increment(); - try { - sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, - multiRecipientMessage.getCommonPayload()); - } catch (NoSuchUserException e) { - uuids404.add(destinationAccount.getUuid()); - } - }); + // we asserted this must exist in validateCompleteDeviceList + Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); + counter.increment(); + try { + sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, + multiRecipientMessage.getCommonPayload()); + } catch (NoSuchUserException e) { + uuids404.add(destinationAccount.getUuid()); + } + return null; + }) + .collect(Collectors.toList())); + } catch (InterruptedException e) { + logger.error("interrupted while delivering multi-recipient messages", e); + return Response.serverError().entity("interrupted during delivery").build(); + } return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerMetricsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerMetricsTest.java index 65651a7e8..56050a423 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerMetricsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerMetricsTest.java @@ -9,6 +9,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; +import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import org.junit.Before; import org.junit.Test; @@ -32,7 +33,8 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest { public void setUp() throws Exception { super.setUp(); - messageController = new MessageController(mock(RateLimiters.class), + messageController = new MessageController( + mock(RateLimiters.class), mock(MessageSender.class), mock(ReceiptSender.class), mock(AccountsManager.class), @@ -43,7 +45,8 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest { mock(RateLimitChallengeManager.class), mock(ReportMessageManager.class), getRedisCluster(), - mock(ScheduledExecutorService.class)); + mock(ScheduledExecutorService.class), + mock(ExecutorService.class)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index 66e3568c7..376b50963 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -47,6 +47,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -86,7 +87,6 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.RateLimitChallenge; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.StaleDevices; -import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -125,20 +125,20 @@ class MessageControllerTest { @SuppressWarnings("unchecked") private static final RedisAdvancedClusterCommands redisCommands = mock(RedisAdvancedClusterCommands.class); - private static final MessageSender messageSender = mock(MessageSender.class); - private static final ReceiptSender receiptSender = mock(ReceiptSender.class); - private static final AccountsManager accountsManager = mock(AccountsManager.class); - private static final MessagesManager messagesManager = mock(MessagesManager.class); - private static final RateLimiters rateLimiters = mock(RateLimiters.class); - private static final RateLimiter rateLimiter = mock(RateLimiter.class); - private static final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class); - private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class); - private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); + private static final MessageSender messageSender = mock(MessageSender.class); + private static final ReceiptSender receiptSender = mock(ReceiptSender.class); + private static final AccountsManager accountsManager = mock(AccountsManager.class); + private static final MessagesManager messagesManager = mock(MessagesManager.class); + private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final RateLimiter rateLimiter = mock(RateLimiter.class); + private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class); + private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); - private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); - private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); - private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands); - private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class); + private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); + private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); + private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands); + private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class); + private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); private final ObjectMapper mapper = new ObjectMapper(); @@ -151,7 +151,8 @@ class MessageControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, - rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor)) + rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor, + multiRecipientMessageExecutor)) .build(); @BeforeEach