Change from using parallel streams to using an ExecutorService

This commit is contained in:
Ehren Kret 2021-08-12 11:37:33 -05:00
parent f7f870fe62
commit d13741fbd5
4 changed files with 63 additions and 43 deletions

View File

@ -417,6 +417,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ExecutorService backupServiceExecutor = environment.lifecycle().executorService(name(getClass(), "backupService-%d")).maxThreads(1).minThreads(1).build(); ExecutorService backupServiceExecutor = environment.lifecycle().executorService(name(getClass(), "backupService-%d")).maxThreads(1).minThreads(1).build();
ExecutorService storageServiceExecutor = environment.lifecycle().executorService(name(getClass(), "storageService-%d")).maxThreads(1).minThreads(1).build(); ExecutorService storageServiceExecutor = environment.lifecycle().executorService(name(getClass(), "storageService-%d")).maxThreads(1).minThreads(1).build();
ExecutorService donationExecutor = environment.lifecycle().executorService(name(getClass(), "donation-%d")).maxThreads(1).minThreads(1).build(); ExecutorService donationExecutor = environment.lifecycle().executorService(name(getClass(), "donation-%d")).maxThreads(1).minThreads(1).build();
ExecutorService multiRecipientMessageExecutor = environment.lifecycle().executorService(name(getClass(), "multiRecipientMessage-%d")).maxThreads(64).build();
ExternalServiceCredentialGenerator directoryCredentialsGenerator = new ExternalServiceCredentialGenerator(config.getDirectoryConfiguration().getDirectoryClientConfiguration().getUserAuthenticationTokenSharedSecret(), ExternalServiceCredentialGenerator directoryCredentialsGenerator = new ExternalServiceCredentialGenerator(config.getDirectoryConfiguration().getDirectoryClientConfiguration().getUserAuthenticationTokenSharedSecret(),
config.getDirectoryConfiguration().getDirectoryClientConfiguration().getUserAuthenticationTokenUserIdSecret(), config.getDirectoryConfiguration().getDirectoryClientConfiguration().getUserAuthenticationTokenUserIdSecret(),
@ -607,7 +608,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keysDynamoDb, rateLimiters, config.getMaxDevices()), new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keysDynamoDb, rateLimiters, config.getMaxDevices()),
new DirectoryController(directoryCredentialsGenerator), new DirectoryController(directoryCredentialsGenerator),
new DonationController(donationExecutor, config.getDonationConfiguration()), new DonationController(donationExecutor, config.getDonationConfiguration()),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, declinedMessageReceiptExecutor), new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, declinedMessageReceiptExecutor, multiRecipientMessageExecutor),
new PaymentsController(currencyManager, paymentsCredentialsGenerator), new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, isZkEnabled), new ProfileController(rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, isZkEnabled),
new ProvisioningController(rateLimiters, provisioningManager), new ProvisioningController(rateLimiters, provisioningManager),

View File

@ -33,16 +33,20 @@ import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
@ -132,6 +136,7 @@ public class MessageController {
private final RateLimitChallengeManager rateLimitChallengeManager; private final RateLimitChallengeManager rateLimitChallengeManager;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ScheduledExecutorService receiptExecutorService; private final ScheduledExecutorService receiptExecutorService;
private final ExecutorService multiRecipientMessageExecutor;
private final Random random = new Random(); private final Random random = new Random();
@ -153,7 +158,8 @@ public class MessageController {
private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes(); private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
public MessageController(RateLimiters rateLimiters, public MessageController(
RateLimiters rateLimiters,
MessageSender messageSender, MessageSender messageSender,
ReceiptSender receiptSender, ReceiptSender receiptSender,
AccountsManager accountsManager, AccountsManager accountsManager,
@ -164,19 +170,20 @@ public class MessageController {
RateLimitChallengeManager rateLimitChallengeManager, RateLimitChallengeManager rateLimitChallengeManager,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
FaultTolerantRedisCluster metricsCluster, FaultTolerantRedisCluster metricsCluster,
ScheduledExecutorService receiptExecutorService) ScheduledExecutorService receiptExecutorService,
{ @Nonnull ExecutorService multiRecipientMessageExecutor) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.messageSender = messageSender; this.messageSender = messageSender;
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.unsealedSenderRateLimiter = unsealedSenderRateLimiter; this.unsealedSenderRateLimiter = unsealedSenderRateLimiter;
this.apnFallbackManager = apnFallbackManager; this.apnFallbackManager = apnFallbackManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitChallengeManager = rateLimitChallengeManager; this.rateLimitChallengeManager = rateLimitChallengeManager;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.receiptExecutorService = receiptExecutorService; this.receiptExecutorService = receiptExecutorService;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
try { try {
recordInternationalUnsealedSenderMetricsScript = ClusterLuaScript.fromResource(metricsCluster, "lua/record_international_unsealed_sender_metrics.lua", ScriptOutputType.MULTI); recordInternationalUnsealedSenderMetricsScript = ClusterLuaScript.fromResource(metricsCluster, "lua/record_international_unsealed_sender_metrics.lua", ScriptOutputType.MULTI);
@ -420,19 +427,27 @@ public class MessageController {
Tag.of(SENDER_TYPE_TAG_NAME, "unidentified")); Tag.of(SENDER_TYPE_TAG_NAME, "unidentified"));
List<UUID> uuids404 = Collections.synchronizedList(new ArrayList<>()); List<UUID> uuids404 = Collections.synchronizedList(new ArrayList<>());
final Counter counter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags); final Counter counter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags);
Arrays.stream(multiRecipientMessage.getRecipients()).parallel().forEach(recipient -> { try {
Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.getRecipients())
.map(recipient -> (Callable<Void>) () -> {
Account destinationAccount = uuidToAccountMap.get(recipient.getUuid());
// we asserted this must exist in validateCompleteDeviceList // we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow();
counter.increment(); counter.increment();
try { try {
sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient,
multiRecipientMessage.getCommonPayload()); multiRecipientMessage.getCommonPayload());
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
uuids404.add(destinationAccount.getUuid()); uuids404.add(destinationAccount.getUuid());
} }
}); return null;
})
.collect(Collectors.toList()));
} catch (InterruptedException e) {
logger.error("interrupted while delivering multi-recipient messages", e);
return Response.serverError().entity("interrupted during delivery").build();
}
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
} }

View File

@ -9,6 +9,7 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -32,7 +33,8 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
public void setUp() throws Exception { public void setUp() throws Exception {
super.setUp(); super.setUp();
messageController = new MessageController(mock(RateLimiters.class), messageController = new MessageController(
mock(RateLimiters.class),
mock(MessageSender.class), mock(MessageSender.class),
mock(ReceiptSender.class), mock(ReceiptSender.class),
mock(AccountsManager.class), mock(AccountsManager.class),
@ -43,7 +45,8 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
mock(RateLimitChallengeManager.class), mock(RateLimitChallengeManager.class),
mock(ReportMessageManager.class), mock(ReportMessageManager.class),
getRedisCluster(), getRedisCluster(),
mock(ScheduledExecutorService.class)); mock(ScheduledExecutorService.class),
mock(ExecutorService.class));
} }
@Test @Test

View File

@ -47,6 +47,7 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -86,7 +87,6 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge; import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -125,20 +125,20 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class); private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
private static final MessageSender messageSender = mock(MessageSender.class); private static final MessageSender messageSender = mock(MessageSender.class);
private static final ReceiptSender receiptSender = mock(ReceiptSender.class); private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final MessagesManager messagesManager = mock(MessagesManager.class); private static final MessagesManager messagesManager = mock(MessagesManager.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class); private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class); private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands); private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands);
private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class); private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
@ -151,7 +151,8 @@ class MessageControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager,
rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor)) rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor,
multiRecipientMessageExecutor))
.build(); .build();
@BeforeEach @BeforeEach