Clarify roles/responsibilities of components in the message-handling pathway

This commit is contained in:
Jon Chambers 2025-01-31 10:24:50 -05:00 committed by GitHub
parent 282bcf6f34
commit 48ada8e8ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 1338 additions and 1199 deletions

View File

@ -431,7 +431,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getRemoteConfig().getTableName()); config.getDynamoDbTables().getRemoteConfig().getTableName());
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient, PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient,
config.getDynamoDbTables().getPushChallenge().getTableName()); config.getDynamoDbTables().getPushChallenge().getTableName());
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getReportMessage().getTableName(), config.getDynamoDbTables().getReportMessage().getTableName(),
config.getReportMessageConfiguration().getReportTtl()); config.getReportMessageConfiguration().getReportTtl());
RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords( RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords(
@ -618,7 +618,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
config.getReportMessageConfiguration().getCounterTtl()); config.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager,
messageDeletionAsyncExecutor); messageDeletionAsyncExecutor, Clock.systemUTC());
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
config.getDynamoDbTables().getDeletedAccountsLock().getTableName()); config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =
@ -1128,7 +1128,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new KeyTransparencyController(keyTransparencyServiceClient), new KeyTransparencyController(keyTransparencyServiceClient),
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender, new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler, accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler,
reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, clientReleaseManager, reportMessageManager, messageDeliveryScheduler, clientReleaseManager,
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor, dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor,
Clock.systemUTC()), Clock.systemUTC()),
new PaymentsController(currencyManager, paymentsCredentialsGenerator), new PaymentsController(currencyManager, paymentsCredentialsGenerator),

View File

@ -7,6 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.util.Collection;
import java.util.function.Predicate;
import java.util.stream.IntStream;
public class UnidentifiedAccessUtil { public class UnidentifiedAccessUtil {
@ -31,4 +34,42 @@ public class UnidentifiedAccessUtil {
.map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey)) .map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey))
.orElse(false); .orElse(false);
} }
/**
* Checks whether an action (e.g. sending a message or retrieving pre-keys) may be taken on the collection of target
* accounts by an actor presenting the given combined unidentified access key.
*
* @param targetAccounts the accounts on which an actor wishes to take an action
* @param combinedUnidentifiedAccessKey the unidentified access key presented by the actor
*
* @return {@code true} if an actor presenting the given unidentified access key has permission to take an action on
* the target accounts or {@code false} otherwise
*/
public static boolean checkUnidentifiedAccess(final Collection<Account> targetAccounts, final byte[] combinedUnidentifiedAccessKey) {
return MessageDigest.isEqual(getCombinedUnidentifiedAccessKey(targetAccounts), combinedUnidentifiedAccessKey);
}
/**
* Calculates a combined unidentified access key for the given collection of accounts.
*
* @param accounts the accounts from which to derive a combined unidentified access key
* @return a combined unidentified access key
*
* @throws IllegalArgumentException if one or more of the given accounts had an unidentified access key with an
* unexpected length
*/
public static byte[] getCombinedUnidentifiedAccessKey(final Collection<Account> accounts) {
return accounts.stream()
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
.map(account ->
account.getUnidentifiedAccessKey()
.filter(b -> b.length == UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH)
.orElseThrow(IllegalArgumentException::new))
.reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH],
(a, b) -> {
final byte[] xor = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
IntStream.range(0, UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH).forEach(i -> xor[i] = (byte) (a[i] ^ b[i]));
return xor;
});
}
} }

View File

@ -9,10 +9,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.dropwizard.util.DataSize; import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -47,33 +45,26 @@ import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status; import jakarta.ws.rs.core.Response.Status;
import java.security.MessageDigest;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
@ -135,6 +126,7 @@ import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples; import reactor.util.function.Tuples;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -142,14 +134,6 @@ import reactor.util.function.Tuples;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages") @io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
public class MessageController { public class MessageController {
private record MultiRecipientDeliveryData(
ServiceIdentifier serviceIdentifier,
Account account,
Recipient recipient,
Map<Byte, Short> deviceIdToRegistrationId) {
}
private static final Logger logger = LoggerFactory.getLogger(MessageController.class); private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
@ -162,7 +146,6 @@ public class MessageController {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final PushNotificationScheduler pushNotificationScheduler; private final PushNotificationScheduler pushNotificationScheduler;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ExecutorService multiRecipientMessageExecutor;
private final Scheduler messageDeliveryScheduler; private final Scheduler messageDeliveryScheduler;
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
@ -229,7 +212,6 @@ public class MessageController {
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler, PushNotificationScheduler pushNotificationScheduler,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
@Nonnull ExecutorService multiRecipientMessageExecutor,
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
final ClientReleaseManager clientReleaseManager, final ClientReleaseManager clientReleaseManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
@ -248,7 +230,6 @@ public class MessageController {
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler; this.pushNotificationScheduler = pushNotificationScheduler;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
this.messageDeliveryScheduler = messageDeliveryScheduler; this.messageDeliveryScheduler = messageDeliveryScheduler;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
@ -332,15 +313,15 @@ public class MessageController {
throw new WebApplicationException(Status.FORBIDDEN); throw new WebApplicationException(Status.FORBIDDEN);
} }
final Optional<Account> destination; final Optional<Account> maybeDestination;
if (!isSyncMessage) { if (!isSyncMessage) {
destination = accountsManager.getByServiceIdentifier(destinationIdentifier); maybeDestination = accountsManager.getByServiceIdentifier(destinationIdentifier);
} else { } else {
destination = source.map(AuthenticatedDevice::getAccount); maybeDestination = source.map(AuthenticatedDevice::getAccount);
} }
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam( final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(
context, source, destination, Optional.of(destinationIdentifier)); context, source, maybeDestination, Optional.of(destinationIdentifier));
final Optional<byte[]> reportSpamToken; final Optional<byte[]> reportSpamToken;
switch (spamCheck) { switch (spamCheck) {
case final SpamChecker.Spam spam: return spam.response(); case final SpamChecker.Spam spam: return spam.response();
@ -376,11 +357,11 @@ public class MessageController {
// Stories will be checked by the client; we bypass access checks here for stories. // Stories will be checked by the client; we bypass access checks here for stories.
} else if (groupSendToken != null) { } else if (groupSendToken != null) {
checkGroupSendToken(List.of(destinationIdentifier.toLibsignal()), groupSendToken); checkGroupSendToken(List.of(destinationIdentifier.toLibsignal()), groupSendToken);
if (destination.isEmpty()) { if (maybeDestination.isEmpty()) {
throw new NotFoundException(); throw new NotFoundException();
} }
} else { } else {
OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, destination, OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, maybeDestination,
destinationIdentifier); destinationIdentifier);
} }
@ -389,20 +370,20 @@ public class MessageController {
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify // We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from // we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
// these requests. // these requests.
if (isStory && destination.isEmpty()) { if (isStory && maybeDestination.isEmpty()) {
return Response.ok(new SendMessageResponse(needsSync)).build(); return Response.ok(new SendMessageResponse(needsSync)).build();
} }
// if destination is empty we would either throw an exception in OptionalAccess.verify when isStory is false // if destination is empty we would either throw an exception in OptionalAccess.verify when isStory is false
// or else return a 200 response when isStory is true. // or else return a 200 response when isStory is true.
assert destination.isPresent(); final Account destination = maybeDestination.orElseThrow();
if (source.isPresent() && !isSyncMessage) { if (source.isPresent() && !isSyncMessage) {
checkMessageRateLimit(source.get(), destination.get(), userAgent); checkMessageRateLimit(source.get(), destination, userAgent);
} }
if (isStory) { if (isStory) {
rateLimiters.getStoriesLimiter().validate(destination.get().getUuid()); rateLimiters.getStoriesLimiter().validate(destination.getUuid());
} }
final Set<Byte> excludedDeviceIds; final Set<Byte> excludedDeviceIds;
@ -413,15 +394,32 @@ public class MessageController {
excludedDeviceIds = Collections.emptySet(); excludedDeviceIds = Collections.emptySet();
} }
DestinationDeviceValidator.validateCompleteDeviceList(destination.get(), final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
messages.messages().stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()), .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
try {
return message.toEnvelope(
destinationIdentifier,
source.map(AuthenticatedDevice::getAccount).orElse(null),
source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null),
messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(),
isStory,
messages.urgent(),
reportSpamToken.orElse(null));
} catch (final IllegalArgumentException e) {
logger.warn("Received bad envelope type {} from {}", message.type(), userAgent);
throw new BadRequestException(e);
}
}));
DestinationDeviceValidator.validateCompleteDeviceList(destination,
messagesByDeviceId.keySet(),
excludedDeviceIds); excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination.get(), DestinationDeviceValidator.validateRegistrationIds(destination,
messages.messages(), messages.messages(),
IncomingMessage::destinationDeviceId, IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId, IncomingMessage::destinationRegistrationId,
destination.get().getPhoneNumberIdentifier().equals(destinationIdentifier.uuid())); destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final String authType; final String authType;
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) { if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
@ -434,31 +432,15 @@ public class MessageController {
authType = AUTH_TYPE_ACCESS_KEY; authType = AUTH_TYPE_ACCESS_KEY;
} }
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), messageSender.sendMessages(destination, messagesByDeviceId);
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE), Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())),
Tag.of(SENDER_TYPE_TAG_NAME, senderType), Tag.of(SENDER_TYPE_TAG_NAME, senderType),
Tag.of(AUTH_TYPE_TAG_NAME, authType), Tag.of(AUTH_TYPE_TAG_NAME, authType),
Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name())); Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name())))
.increment(messagesByDeviceId.size());
for (final IncomingMessage incomingMessage : messages.messages()) {
destination.get().getDevice(incomingMessage.destinationDeviceId())
.ifPresent(destinationDevice -> {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendIndividualMessage(
source,
destination.get(),
destinationDevice,
destinationIdentifier,
messages.timestamp(),
messages.online(),
isStory,
messages.urgent(),
incomingMessage,
userAgent,
reportSpamToken);
});
}
return Response.ok(new SendMessageResponse(needsSync)).build(); return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {
@ -481,34 +463,6 @@ public class MessageController {
} }
} }
/**
* Build mapping of service IDs to resolved accounts and device/registration IDs
*/
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.switchIfEmpty(Flux.error(BadRequestException::new))
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
.flatMap(
t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1()))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.map(
account ->
new MultiRecipientDeliveryData(
t.getT1(),
account,
t.getT2(),
t.getT2().getDevicesAndRegistrationIds().collect(
Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second))))
// IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e),
MAX_FETCH_ACCOUNT_CONCURRENCY)
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
.block();
}
@Timed @Timed
@Path("/multi_recipient") @Path("/multi_recipient")
@PUT @PUT
@ -565,6 +519,32 @@ public class MessageController {
throw new BadRequestException("Illegal timestamp"); throw new BadRequestException("Illegal timestamp");
} }
if (multiRecipientMessage.getRecipients().isEmpty()) {
throw new BadRequestException("Recipient list is empty");
}
// Verify that the message isn't too large before performing more expensive validations
multiRecipientMessage.getRecipients().values().forEach(recipient ->
validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, userAgent));
// Check that the request is well-formed and doesn't contain repeated entries for the same device for the same
// recipient
{
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID];
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
Arrays.fill(usedDeviceIds, false);
for (final byte deviceId : recipient.getDevices()) {
if (usedDeviceIds[deviceId]) {
throw new BadRequestException();
}
usedDeviceIds[deviceId] = true;
}
}
}
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty(), Optional.empty()); final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty(), Optional.empty());
if (spamCheck instanceof final SpamChecker.Spam spam) { if (spamCheck instanceof final SpamChecker.Spam spam) {
return spam.response(); return spam.response();
@ -584,28 +564,43 @@ public class MessageController {
if (groupSendToken != null) { if (groupSendToken != null) {
// Group send endorsements are checked before we even attempt to resolve any accounts, since // Group send endorsements are checked before we even attempt to resolve any accounts, since
// the lists of service IDs in the envelope are all that we need to check against // the lists of service IDs in the envelope are all that we need to check against
checkGroupSendToken( checkGroupSendToken(multiRecipientMessage.getRecipients().keySet(), groupSendToken);
multiRecipientMessage.getRecipients().keySet(), groupSendToken);
} }
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory); // At this point, the caller has at least superficially provided the information needed to send a multi-recipient
// message. Attempt to resolve the destination service identifiers to Signal accounts.
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.flatMap(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier =
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
}, MAX_FETCH_ACCOUNT_CONCURRENCY)
.collectMap(Tuple2::getT1, Tuple2::getT2)
.blockOptional()
.orElse(Collections.emptyMap());
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above. // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them // Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
if (groupSendToken == null && !isStory) { if (groupSendToken == null && !isStory) {
checkAccessKeys(accessKeys, recipients.values()); checkAccessKeys(accessKeys, multiRecipientMessage, resolvedRecipients);
} }
// We might filter out all the recipients of a story (if none exist). // We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
if (isStory) { if (isStory) {
if (recipients.isEmpty()) { if (resolvedRecipients.isEmpty()) {
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build(); return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
} }
try { try {
CompletableFuture.allOf(recipients.values() CompletableFuture.allOf(resolvedRecipients.values()
.stream() .stream()
.map(recipient -> recipient.account().getUuid()) .map(account -> account.getIdentifier(IdentityType.ACI))
.map(accountIdentifier -> .map(accountIdentifier ->
rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture()) rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture())
.toList() .toList()
@ -620,31 +615,42 @@ public class MessageController {
} }
} }
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
recipients.values().forEach(recipient -> {
final Account account = recipient.account(); multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
// can just skip them.
return;
}
final Account account = resolvedRecipients.get(recipient);
try { try {
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
Collections.emptySet()); Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds( DestinationDeviceValidator.validateRegistrationIds(
account, account,
recipient.deviceIdToRegistrationId().entrySet(), deviceIdsToRegistrationIds.entrySet(),
Map.Entry<Byte, Short>::getKey, Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()), e -> Integer.valueOf(e.getValue()),
recipient.serviceIdentifier().identityType() == IdentityType.PNI); serviceId instanceof ServiceId.Pni);
} catch (MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {
accountMismatchedDevices.add( accountMismatchedDevices.add(
new AccountMismatchedDevices( new AccountMismatchedDevices(
recipient.serviceIdentifier(), ServiceIdentifier.fromLibsignal(serviceId),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) { } catch (final StaleDevicesException e) {
accountStaleDevices.add( accountStaleDevices.add(
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices()))); new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
} }
}); });
if (!accountMismatchedDevices.isEmpty()) { if (!accountMismatchedDevices.isEmpty()) {
return Response return Response
.status(409) .status(409)
@ -670,39 +676,30 @@ public class MessageController {
} }
try { try {
final byte[] sharedMrmKey = messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage); messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, timestamp, isStory, online, isUrgent).get();
CompletableFuture.allOf( multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
recipients.values().stream() if (!resolvedRecipients.containsKey(recipient)) {
.flatMap(recipientData -> { // We skipped sending to this recipient because we're sending a story and couldn't resolve the recipient to
final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of( // an existing account; don't increment the counter for this recipient.
UserAgentTagUtil.getPlatformTag(userAgent), return;
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI), }
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED),
Tag.of(AUTH_TYPE_TAG_NAME, authType),
Tag.of(IDENTITY_TYPE_TAG_NAME, recipientData.serviceIdentifier().identityType().name())));
validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipientData.recipient()), true, userAgent); final String identityType = switch (serviceId) {
case ServiceId.Aci ignored -> "ACI";
case ServiceId.Pni ignored -> "PNI";
default -> "unknown";
};
return recipientData.deviceIdToRegistrationId().keySet().stream().map( Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of(
deviceId -> CompletableFuture.runAsync( UserAgentTagUtil.getPlatformTag(userAgent),
() -> { Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI),
final Account destinationAccount = recipientData.account(); Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient()); Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED),
Tag.of(AUTH_TYPE_TAG_NAME, authType),
// we asserted this must exist in validateCompleteDeviceList Tag.of(IDENTITY_TYPE_TAG_NAME, identityType)))
final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow(); .increment(recipient.getDevices().length);
});
sentMessageCounter.increment();
sendCommonPayloadMessage(
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp,
online, isStory, isUrgent, payload, sharedMrmKey);
},
multiRecipientMessageExecutor));
})
.toArray(CompletableFuture[]::new))
.get();
} catch (InterruptedException e) { } catch (InterruptedException e) {
logger.error("interrupted while delivering multi-recipient messages", e); logger.error("interrupted while delivering multi-recipient messages", e);
throw new InternalServerErrorException("interrupted during delivery"); throw new InternalServerErrorException("interrupted during delivery");
@ -729,29 +726,21 @@ public class MessageController {
private void checkAccessKeys( private void checkAccessKeys(
final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys, final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys,
final Collection<MultiRecipientDeliveryData> destinations) { final SealedSenderMultiRecipientMessage multiRecipientMessage,
final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH; final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients) {
if (multiRecipientMessage.getRecipients().keySet().stream()
.anyMatch(serviceId -> serviceId instanceof ServiceId.Pni)) {
if (destinations.stream()
.anyMatch(destination -> IdentityType.PNI.equals(destination.serviceIdentifier.identityType()))) {
throw new WebApplicationException("Multi-recipient messages must be addressed to ACI service IDs", throw new WebApplicationException("Multi-recipient messages must be addressed to ACI service IDs",
Status.UNAUTHORIZED); Status.UNAUTHORIZED);
} }
final byte[] combinedUnidentifiedAccessKeys = destinations.stream() try {
.map(MultiRecipientDeliveryData::account) if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(resolvedRecipients.values(), accessKeys.getAccessKeys())) {
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess)) throw new WebApplicationException(Status.UNAUTHORIZED);
.map(account -> }
account.getUnidentifiedAccessKey() } catch (final IllegalArgumentException ignored) {
.filter(b -> b.length == keyLength)
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
.reduce(new byte[keyLength],
(a, b) -> {
final byte[] xor = new byte[keyLength];
IntStream.range(0, keyLength).forEach(i -> xor[i] = (byte) (a[i] ^ b[i]));
return xor;
});
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) {
throw new WebApplicationException(Status.UNAUTHORIZED); throw new WebApplicationException(Status.UNAUTHORIZED);
} }
} }
@ -912,65 +901,6 @@ public class MessageController {
.build(); .build();
} }
private void sendIndividualMessage(
Optional<AuthenticatedDevice> source,
Account destinationAccount,
Device destinationDevice,
ServiceIdentifier destinationIdentifier,
long timestamp,
boolean online,
boolean story,
boolean urgent,
IncomingMessage incomingMessage,
String userAgentString,
Optional<byte[]> spamReportToken) {
final Envelope envelope;
try {
final Account sourceAccount = source.map(AuthenticatedDevice::getAccount).orElse(null);
final Byte sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
envelope = incomingMessage.toEnvelope(
destinationIdentifier,
sourceAccount,
sourceDeviceId,
timestamp == 0 ? System.currentTimeMillis() : timestamp,
story,
urgent,
spamReportToken.orElse(null));
} catch (final IllegalArgumentException e) {
logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString);
throw new BadRequestException(e);
}
messageSender.sendMessage(destinationAccount, destinationDevice, envelope, online);
}
private void sendCommonPayloadMessage(Account destinationAccount,
Device destinationDevice,
ServiceIdentifier serviceIdentifier,
long timestamp,
boolean online,
boolean story,
boolean urgent,
byte[] payload,
byte[] sharedMrmKey) {
final Envelope.Builder messageBuilder = Envelope.newBuilder();
final long serverTimestamp = System.currentTimeMillis();
messageBuilder
.setType(Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp)
.setStory(story)
.setUrgent(urgent)
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
}
private void checkMessageRateLimit(AuthenticatedDevice source, Account destination, String userAgent) private void checkMessageRateLimit(AuthenticatedDevice source, Account destination, String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());
@ -1020,15 +950,4 @@ public class MessageController {
throw new BadRequestException("reserved envelope type"); throw new BadRequestException("reserved envelope type");
} }
} }
public static Optional<byte[]> getMessageContent(IncomingMessage message) {
if (StringUtils.isEmpty(message.content())) return Optional.empty();
try {
return Optional.of(Base64.getDecoder().decode(message.content()));
} catch (IllegalArgumentException e) {
logger.debug("Bad B64", e);
return Optional.empty();
}
}
} }

View File

@ -55,11 +55,7 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<SealedSe
try { try {
final SealedSenderMultiRecipientMessage message = SealedSenderMultiRecipientMessage.parse(fullMessage); final SealedSenderMultiRecipientMessage message = SealedSenderMultiRecipientMessage.parse(fullMessage);
RECIPIENT_COUNT_DISTRIBUTION.record(message.getRecipients().keySet().size()); RECIPIENT_COUNT_DISTRIBUTION.record(message.getRecipients().size());
if (message.getRecipients().values().stream().anyMatch(r -> message.messageSizeForRecipient(r) > MAX_MESSAGE_SIZE)) {
throw new BadRequestException("message payload too large");
}
return message; return message;
} catch (InvalidMessageException | InvalidVersionException e) { } catch (InvalidMessageException | InvalidVersionException e) {
throw new BadRequestException(e); throw new BadRequestException(e);

View File

@ -9,9 +9,14 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util;
/** /**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
@ -42,26 +47,82 @@ public class MessageSender {
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
} }
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean online) { /**
final boolean destinationPresent = messagesManager.insert(account.getUuid(), * Sends messages to devices associated with the given destination account. If a destination device has a valid push
device.getId(), * notification token and does not have an active connection to a Signal server, then this method will also send a
online ? message.toBuilder().setEphemeral(true).build() : message); * push notification to that device to announce the availability of new messages.
*
* @param account the account to which to send messages
* @param messagesByDeviceId a map of device IDs to message payloads
*/
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !online) { if (!destinationPresent && !message.getEphemeral()) {
try { try {
pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent()); pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) { } catch (final NotPushRegisteredException ignored) {
} }
} }
Metrics.counter(SEND_COUNTER_NAME, Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, getDeliveryChannelName(device), CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(online), EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent), CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()), STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId())) SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment(); .increment();
});
}
/**
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to
* announce the availability of new messages.
*
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
* @param resolvedRecipients a map of recipients to resolved Signal accounts
* @param clientTimestamp the time at which the sender reports the message was sent
* @param isStory {@code true} if the message is a story or {@code false otherwise}
* @param isEphemeral {@code true} if the message should only be delivered to devices with active connections or
* {@code false otherwise}
* @param isUrgent {@code true} if the message is urgent or {@code false otherwise}
*
* @return a future that completes when all messages have been inserted into delivery queues
*/
public CompletableFuture<Void> sendMultiRecipientMessage(final SealedSenderMultiRecipientMessage multiRecipientMessage,
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients,
final long clientTimestamp,
final boolean isStory,
final boolean isEphemeral,
final boolean isUrgent) {
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
isStory, isEphemeral, isUrgent)
.thenAccept(clientPresenceByAccountAndDevice ->
clientPresenceByAccountAndDevice.forEach((account, clientPresenceByDeviceId) ->
clientPresenceByDeviceId.forEach((deviceId, clientPresent) -> {
if (!clientPresent && !isEphemeral) {
try {
pushNotificationManager.sendNewMessageNotification(account, deviceId, isUrgent);
} catch (final NotPushRegisteredException ignored) {
}
}
Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME,
account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(isEphemeral),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(isUrgent),
STORY_TAG_NAME, String.valueOf(isStory),
SEALED_SENDER_TAG_NAME, String.valueOf(true))
.increment();
})))
.thenRun(Util.NOOP);
} }
@VisibleForTesting @VisibleForTesting

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
@ -43,21 +44,21 @@ public class ReceiptSender {
try { try {
accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse( accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse(
destinationAccount -> { destinationAccount -> {
final Envelope.Builder message = Envelope.newBuilder() final Envelope message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setSourceServiceId(sourceIdentifier.toServiceIdentifierString()) .setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice(sourceDeviceId) .setSourceDevice(sourceDeviceId)
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString()) .setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
.setClientTimestamp(messageId) .setClientTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false); .setUrgent(false)
.build();
for (final Device destinationDevice : destinationAccount.getDevices()) { try {
try { messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
messageSender.sendMessage(destinationAccount, destinationDevice, message.build(), false); .collect(Collectors.toMap(Device::getId, ignored -> message)));
} catch (final Exception e) { } catch (final Exception e) {
logger.warn("Could not send delivery receipt", e); logger.warn("Could not send delivery receipt", e);
}
} }
}, },
() -> logger.info("No longer registered: {}", destinationIdentifier) () -> logger.info("No longer registered: {}", destinationIdentifier)

View File

@ -4,8 +4,8 @@
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -13,10 +13,10 @@ import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@ -115,40 +115,39 @@ public class ChangeNumberManager {
private void sendDeviceMessages(final Account account, final List<IncomingMessage> deviceMessages) { private void sendDeviceMessages(final Account account, final List<IncomingMessage> deviceMessages) {
try { try {
deviceMessages.forEach(message -> final long serverTimestamp = System.currentTimeMillis();
sendMessageToSelf(account, account.getDevice(message.destinationDeviceId()), message));
} catch (RuntimeException e) { messageSender.sendMessages(account, deviceMessages.stream()
.filter(message -> getMessageContent(message).isPresent())
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(getMessageContent(message).orElseThrow()))
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.setEphemeral(false)
.build())));
} catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e; throw e;
} }
} }
@VisibleForTesting private static Optional<byte[]> getMessageContent(final IncomingMessage message) {
void sendMessageToSelf( if (StringUtils.isEmpty(message.content())) {
Account sourceAndDestinationAccount, Optional<Device> destinationDevice, IncomingMessage message) { logger.warn("Message has no content");
Optional<byte[]> contents = MessageController.getMessageContent(message); return Optional.empty();
if (contents.isEmpty()) {
logger.debug("empty message contents sending to self, ignoring");
return;
} else if (destinationDevice.isEmpty()) {
logger.debug("destination device not present");
return;
} }
final long serverTimestamp = System.currentTimeMillis(); try {
final Envelope envelope = Envelope.newBuilder() return Optional.of(Base64.getDecoder().decode(message.content()));
.setType(Envelope.Type.forNumber(message.type())) } catch (final IllegalArgumentException e) {
.setClientTimestamp(serverTimestamp) logger.warn("Failed to parse message content", e);
.setServerTimestamp(serverTimestamp) return Optional.empty();
.setDestinationServiceId( }
new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(contents.get()))
.setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.build();
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
} }
} }

View File

@ -203,22 +203,28 @@ public class MessagesCache {
this.unlockQueueScript = unlockQueueScript; this.unlockQueueScript = unlockQueueScript;
} }
public boolean insert(final UUID messageGuid, public CompletableFuture<Boolean> insert(final UUID messageGuid,
final UUID destinationAccountIdentifier, final UUID destinationAccountIdentifier,
final byte destinationDeviceId, final byte destinationDeviceId,
final MessageProtos.Envelope message) { final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build(); final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build();
return insertTimer.record(() -> insertScript.execute(destinationAccountIdentifier, destinationDeviceId, messageWithGuid)); final Timer.Sample sample = Timer.start();
return insertScript.executeAsync(destinationAccountIdentifier, destinationDeviceId, messageWithGuid)
.whenComplete((ignored, throwable) -> sample.stop(insertTimer));
} }
public byte[] insertSharedMultiRecipientMessagePayload( public CompletableFuture<byte[]> insertSharedMultiRecipientMessagePayload(
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return insertSharedMrmPayloadTimer.record(() -> {
final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID()); final Timer.Sample sample = Timer.start();
insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage);
return sharedMrmKey; final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID());
});
return insertMrmScript.executeAsync(sharedMrmKey, sealedSenderMultiRecipientMessage)
.thenApply(ignored -> sharedMrmKey)
.whenComplete((ignored, throwable) -> sample.stop(insertSharedMrmPayloadTimer));
} }
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice, public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,

View File

@ -12,6 +12,7 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientEvent; import org.whispersystems.textsecuregcm.push.ClientEvent;
import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent; import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent;
@ -44,7 +45,7 @@ class MessagesCacheInsertScript {
* @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false} * @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false}
* otherwise * otherwise
*/ */
boolean execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { CompletableFuture<Boolean> executeAsync(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid(); assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp(); assert envelope.hasServerTimestamp();
@ -62,6 +63,7 @@ class MessagesCacheInsertScript {
NEW_MESSAGE_EVENT_BYTES // eventPayload NEW_MESSAGE_EVENT_BYTES // eventPayload
)); ));
return (boolean) insertScript.executeBinary(keys, args); return insertScript.executeBinaryAsync(keys, args)
.thenApply(result -> (boolean) result);
} }
} }

View File

@ -9,9 +9,11 @@ import io.lettuce.core.ScriptOutputType;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.util.Util;
/** /**
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as * Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
@ -31,7 +33,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
ScriptOutputType.INTEGER); ScriptOutputType.INTEGER);
} }
void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) { CompletableFuture<Void> executeAsync(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
final List<byte[]> keys = List.of( final List<byte[]> keys = List.of(
sharedMrmKey // sharedMrmKey sharedMrmKey // sharedMrmKey
); );
@ -47,6 +49,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
} }
}); });
script.executeBinary(keys, args); return script.executeBinaryAsync(keys, args)
.thenRun(Util.NOOP);
} }
} }

View File

@ -6,18 +6,23 @@ package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.protobuf.ByteString;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.reactivestreams.Publisher; import org.reactivestreams.Publisher;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
@ -25,6 +30,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.observability.micrometer.Micrometer; import reactor.core.observability.micrometer.Micrometer;
@ -48,41 +55,120 @@ public class MessagesManager {
private final MessagesCache messagesCache; private final MessagesCache messagesCache;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ExecutorService messageDeletionExecutor; private final ExecutorService messageDeletionExecutor;
private final Clock clock;
public MessagesManager( public MessagesManager(
final MessagesDynamoDb messagesDynamoDb, final MessagesDynamoDb messagesDynamoDb,
final MessagesCache messagesCache, final MessagesCache messagesCache,
final ReportMessageManager reportMessageManager, final ReportMessageManager reportMessageManager,
final ExecutorService messageDeletionExecutor) { final ExecutorService messageDeletionExecutor,
final Clock clock) {
this.messagesDynamoDb = messagesDynamoDb; this.messagesDynamoDb = messagesDynamoDb;
this.messagesCache = messagesCache; this.messagesCache = messagesCache;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.messageDeletionExecutor = messageDeletionExecutor; this.messageDeletionExecutor = messageDeletionExecutor;
this.clock = clock;
} }
/** /**
* Inserts a message into a target device's message queue and notifies registered listeners that a new message is * Inserts messages into the message queues for devices associated with the identified account.
* available.
* *
* @param destinationUuid the account identifier for the destination queue * @param accountIdentifier the account identifier for the destination queue
* @param destinationDeviceId the device ID for the destination queue * @param messagesByDeviceId a map of device IDs to messages
* @param message the message to insert into the queue
* *
* @return {@code true} if the destination device is "present" (i.e. has an active event listener) or {@code false} * @return a map of device IDs to a device's presence state (i.e. if the device has an active event listener)
* otherwise
* *
* @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager * @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager
*/ */
public boolean insert(final UUID destinationUuid, final byte destinationDeviceId, final Envelope message) { public Map<Byte, Boolean> insert(final UUID accountIdentifier, final Map<Byte, Envelope> messagesByDeviceId) {
final UUID messageGuid = UUID.randomUUID(); return insertAsync(accountIdentifier, messagesByDeviceId).join();
}
final boolean destinationPresent = messagesCache.insert(messageGuid, destinationUuid, destinationDeviceId, message); private CompletableFuture<Map<Byte, Boolean>> insertAsync(final UUID accountIdentifier, final Map<Byte, Envelope> messagesByDeviceId) {
final Map<Byte, Boolean> devicePresenceById = new ConcurrentHashMap<>();
if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) { return CompletableFuture.allOf(messagesByDeviceId.entrySet().stream()
reportMessageManager.store(message.getSourceServiceId(), messageGuid); .map(deviceIdAndMessage -> {
} final byte deviceId = deviceIdAndMessage.getKey();
final Envelope message = deviceIdAndMessage.getValue();
final UUID messageGuid = UUID.randomUUID();
return destinationPresent; return messagesCache.insert(messageGuid, accountIdentifier, deviceId, message)
.thenAccept(present -> {
if (message.hasSourceServiceId() && !accountIdentifier.toString()
.equals(message.getSourceServiceId())) {
// Note that this is an asynchronous, best-effort, fire-and-forget operation
reportMessageManager.store(message.getSourceServiceId(), messageGuid);
}
devicePresenceById.put(deviceId, present);
});
})
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> devicePresenceById);
}
/**
* Inserts messages into the message queues for devices associated with the identified accounts.
*
* @param multiRecipientMessage the multi-recipient message to insert into destination queues
* @param resolvedRecipients a map of multi-recipient message {@code Recipient} entities to their corresponding
* Signal accounts; messages will not be delivered to unresolved recipients
* @param clientTimestamp the timestamp for the message as reported by the sending party
* @param isStory {@code true} if the given message is a story or {@code false} otherwise
* @param isEphemeral {@code true} if the given message should only be delivered to devices with active
* connections to a Signal server or {@code false} otherwise
* @param isUrgent {@code true} if the given message is urgent or {@code false} otherwise
*
* @return a map of accounts to maps of device IDs to a device's presence state (i.e. if the device has an active
* event listener)
*
* @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager
*/
public CompletableFuture<Map<Account, Map<Byte, Boolean>>> insertMultiRecipientMessage(
final SealedSenderMultiRecipientMessage multiRecipientMessage,
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients,
final long clientTimestamp,
final boolean isStory,
final boolean isEphemeral,
final boolean isUrgent) {
final long serverTimestamp = clock.millis();
return insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
.thenCompose(sharedMrmKey -> {
final Envelope prototypeMessage = Envelope.newBuilder()
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp)
.setServerTimestamp(serverTimestamp)
.setStory(isStory)
.setEphemeral(isEphemeral)
.setUrgent(isUrgent)
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build();
final Map<Account, Map<Byte, Boolean>> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>();
return CompletableFuture.allOf(multiRecipientMessage.getRecipients().entrySet().stream()
.filter(serviceIdAndRecipient -> resolvedRecipients.containsKey(serviceIdAndRecipient.getValue()))
.map(serviceIdAndRecipient -> {
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
final SealedSenderMultiRecipientMessage.Recipient recipient = serviceIdAndRecipient.getValue();
final byte[] devices = recipient.getDevices();
return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI),
IntStream.range(0, devices.length).mapToObj(i -> devices[i])
.collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder()
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.build())))
.thenAccept(clientPresenceByDeviceId ->
clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient),
clientPresenceByDeviceId));
})
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> clientPresenceByAccountAndDevice);
});
} }
public CompletableFuture<Boolean> mayHavePersistedMessages(final UUID destinationUuid, final Device destinationDevice) { public CompletableFuture<Boolean> mayHavePersistedMessages(final UUID destinationUuid, final Device destinationDevice) {
@ -217,7 +303,7 @@ public class MessagesManager {
* @return a key where the shared data is stored * @return a key where the shared data is stored
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript * @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
*/ */
public byte[] insertSharedMultiRecipientMessagePayload( private CompletableFuture<byte[]> insertSharedMultiRecipientMessagePayload(
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) { final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage); return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage);
} }

View File

@ -3,6 +3,8 @@ package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
@ -11,6 +13,7 @@ import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
@ -20,6 +23,7 @@ public class ReportMessageDynamoDb {
static final String ATTR_TTL = "E"; static final String ATTR_TTL = "E";
private final DynamoDbClient db; private final DynamoDbClient db;
private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName; private final String tableName;
private final Duration ttl; private final Duration ttl;
@ -30,20 +34,26 @@ public class ReportMessageDynamoDb {
.distributionStatisticExpiry(Duration.ofDays(1)) .distributionStatisticExpiry(Duration.ofDays(1))
.register(Metrics.globalRegistry); .register(Metrics.globalRegistry);
public ReportMessageDynamoDb(final DynamoDbClient dynamoDB, final String tableName, final Duration ttl) { public ReportMessageDynamoDb(final DynamoDbClient dynamoDB,
final DynamoDbAsyncClient dynamoDbAsyncClient,
final String tableName,
final Duration ttl) {
this.db = dynamoDB; this.db = dynamoDB;
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName; this.tableName = tableName;
this.ttl = ttl; this.ttl = ttl;
} }
public void store(byte[] hash) { public CompletableFuture<Void> store(byte[] hash) {
db.putItem(PutItemRequest.builder() return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName) .tableName(tableName)
.item(Map.of( .item(Map.of(
KEY_HASH, AttributeValues.fromByteArray(hash), KEY_HASH, AttributeValues.fromByteArray(hash),
ATTR_TTL, AttributeValues.fromLong(Instant.now().plus(ttl).getEpochSecond()) ATTR_TTL, AttributeValues.fromLong(Instant.now().plus(ttl).getEpochSecond())
)) ))
.build()); .build())
.thenRun(Util.NOOP);
} }
public boolean remove(byte[] hash) { public boolean remove(byte[] hash) {

View File

@ -54,11 +54,8 @@ public class ReportMessageManager {
} }
public void store(String sourceAci, UUID messageGuid) { public void store(String sourceAci, UUID messageGuid) {
try { try {
Objects.requireNonNull(sourceAci); reportMessageDynamoDb.store(hash(messageGuid, Objects.requireNonNull(sourceAci)));
reportMessageDynamoDb.store(hash(messageGuid, sourceAci));
} catch (final Exception e) { } catch (final Exception e) {
logger.warn("Failed to store hash", e); logger.warn("Failed to store hash", e);
} }

View File

@ -22,13 +22,15 @@ public class DestinationDeviceValidator {
/** /**
* @see #validateRegistrationIds(Account, Stream, boolean) * @see #validateRegistrationIds(Account, Stream, boolean)
*/ */
public static <T> void validateRegistrationIds(final Account account, final Collection<T> messages, public static <T> void validateRegistrationIds(final Account account,
Function<T, Byte> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity) final Collection<T> messages,
throws StaleDevicesException { Function<T, Byte> getDeviceId,
Function<T, Integer> getRegistrationId,
boolean usePhoneNumberIdentity) throws StaleDevicesException {
validateRegistrationIds(account, validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))), messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
usePhoneNumberIdentity); usePhoneNumberIdentity);
} }
/** /**

View File

@ -217,13 +217,13 @@ record CommandDependencies(
MessagesCache messagesCache = new MessagesCache(messagesCluster, MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(), configuration.getDynamoDbTables().getReportMessage().getTableName(),
configuration.getReportMessageConfiguration().getReportTtl()); configuration.getReportMessageConfiguration().getReportTtl());
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
configuration.getReportMessageConfiguration().getCounterTtl()); configuration.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager, messageDeletionExecutor); reportMessageManager, messageDeletionExecutor, Clock.systemUTC());
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName()); configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =

View File

@ -10,23 +10,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.UUID; import java.util.UUID;
import org.apache.commons.lang3.RandomStringUtils; import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -49,17 +51,21 @@ class MessageSenderTest {
@CartesianTest @CartesianTest
void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent, void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean onlineMessage, @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException { @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
final boolean expectPushNotificationAttempt = !clientPresent && !onlineMessage; final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID; final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class); final Account account = mock(Account.class);
final Device device = mock(Device.class); final Device device = mock(Device.class);
final MessageProtos.Envelope message = generateRandomMessage(); final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
.setEphemeral(ephemeral)
.setUrgent(urgent)
.build();
when(account.getUuid()).thenReturn(accountIdentifier); when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
@ -72,18 +78,61 @@ class MessageSenderTest {
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean()); .when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
} }
when(messagesManager.insert(eq(accountIdentifier), eq(deviceId), any())).thenReturn(clientPresent); when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, onlineMessage)); assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
final MessageProtos.Envelope expectedMessage = onlineMessage final MessageProtos.Envelope expectedMessage = ephemeral
? message.toBuilder().setEphemeral(true).build() ? message.toBuilder().setEphemeral(true).build()
: message.toBuilder().build(); : message.toBuilder().build();
verify(messagesManager).insert(accountIdentifier, deviceId, expectedMessage); verify(messagesManager).insert(accountIdentifier, Map.of(deviceId, expectedMessage));
if (expectPushNotificationAttempt) { if (expectPushNotificationAttempt) {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, expectedMessage.getUrgent()); verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
} else {
verifyNoInteractions(pushNotificationManager);
}
}
@CartesianTest
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(device.getId()).thenReturn(deviceId);
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
} else {
doThrow(NotPushRegisteredException.class)
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
}
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
Collections.emptyMap(),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
.join());
if (expectPushNotificationAttempt) {
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
} else { } else {
verifyNoInteractions(pushNotificationManager); verifyNoInteractions(pushNotificationManager);
} }
@ -123,14 +172,4 @@ class MessageSenderTest {
return arguments; return arguments;
} }
private MessageProtos.Envelope generateRandomMessage() {
return MessageProtos.Envelope.newBuilder()
.setClientTimestamp(System.currentTimeMillis())
.setServerTimestamp(System.currentTimeMillis())
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(UUID.randomUUID().toString())
.build();
}
} }

View File

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -104,7 +105,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); verify(messageSender, never()).sendMessages(eq(account), any());
} }
@Test @Test
@ -118,7 +119,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); verify(messageSender, never()).sendMessages(eq(account), any());
} }
@Test @Test
@ -155,10 +156,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@ -203,10 +209,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@ -249,10 +260,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@ -291,10 +307,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
@ -335,10 +356,15 @@ public class ChangeNumberManagerTest {
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));

View File

@ -84,7 +84,7 @@ class MessagePersisterIntegrationTest {
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC()); messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
messageDeletionExecutorService); messageDeletionExecutorService, Clock.systemUTC());
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor(); asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor();
@ -143,7 +143,7 @@ class MessagePersisterIntegrationTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message); messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message).join();
expectedMessages.add(message); expectedMessages.add(message);
} }

View File

@ -358,7 +358,7 @@ class MessagePersisterTest {
.setServerGuid(messageGuid.toString()) .setServerGuid(messageGuid.toString())
.build(); .build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope); messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join();
} }
} }

View File

@ -40,7 +40,7 @@ class MessagesCacheGetItemsScriptTest {
.setServerGuid(serverGuid) .setServerGuid(serverGuid)
.build(); .build();
insertScript.execute(destinationUuid, deviceId, envelope1); insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript( final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@ -41,7 +41,7 @@ class MessagesCacheInsertScriptTest {
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build(); .build();
insertScript.execute(destinationUuid, deviceId, envelope1); insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId)); assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
@ -50,11 +50,11 @@ class MessagesCacheInsertScriptTest {
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build(); .build();
insertScript.execute(destinationUuid, deviceId, envelope2); insertScript.executeAsync(destinationUuid, deviceId, envelope2);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId)); assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
insertScript.execute(destinationUuid, deviceId, envelope1); insertScript.executeAsync(destinationUuid, deviceId, envelope1);
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId), assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
"Messages with same GUID should be deduplicated"); "Messages with same GUID should be deduplicated");
@ -89,10 +89,10 @@ class MessagesCacheInsertScriptTest {
final MessagesCacheInsertScript insertScript = final MessagesCacheInsertScript insertScript =
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster()); new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() assertFalse(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond()) .setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build())); .build()).join());
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection = final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection(); REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
@ -100,9 +100,9 @@ class MessagesCacheInsertScriptTest {
pubSubClusterConnection.usePubSubConnection(connection -> pubSubClusterConnection.usePubSubConnection(connection ->
connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId))); connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId)));
assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder() assertTrue(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond()) .setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build())); .build()).join());
} }
} }

View File

@ -6,7 +6,9 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import io.lettuce.core.RedisCommandExecutionException; import io.lettuce.core.RedisCommandExecutionException;
import java.util.ArrayList; import java.util.ArrayList;
@ -14,8 +16,10 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import io.lettuce.core.RedisException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
@ -39,8 +43,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(destinations)); MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum(); final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
@ -82,15 +86,17 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)); MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)).join();
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class, final CompletionException completionException = assertThrows(CompletionException.class,
() -> insertMrmScript.execute(sharedMrmKey, () -> insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
Device.PRIMARY_ID))); Device.PRIMARY_ID)).join());
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage()); assertInstanceOf(RedisException.class, completionException.getCause());
assertTrue(completionException.getCause().getMessage()
.contains(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS));
} }
} }

View File

@ -34,7 +34,7 @@ class MessagesCacheRemoveByGuidScriptTest {
.setServerGuid(serverGuid.toString()) .setServerGuid(serverGuid.toString())
.build(); .build();
insertScript.execute(destinationUuid, deviceId, envelope1); insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript( final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@ -35,7 +35,7 @@ class MessagesCacheRemoveQueueScriptTest {
.setServerGuid(UUID.randomUUID().toString()) .setServerGuid(UUID.randomUUID().toString())
.build(); .build();
insertScript.execute(destinationUuid, deviceId, envelope1); insertScript.executeAsync(destinationUuid, deviceId, envelope1);
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript( final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());

View File

@ -41,8 +41,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
MessagesCacheTest.generateRandomMrmMessage(destinations));
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript( final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
@ -103,8 +102,8 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
REDIS_CLUSTER_EXTENSION.getRedisCluster()); REDIS_CLUSTER_EXTENSION.getRedisCluster());
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID()); final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
insertMrmScript.execute(sharedMrmKey, insertMrmScript.executeAsync(sharedMrmKey,
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)); MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)).join();
sharedMrmKeys.add(sharedMrmKey); sharedMrmKeys.add(sharedMrmKey);
} }

View File

@ -122,7 +122,7 @@ class MessagesCacheTest {
void testInsert(final boolean sealedSender) { void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender))); generateRandomMessage(messageGuid, sealedSender))).join();
} }
@Test @Test
@ -130,8 +130,8 @@ class MessagesCacheTest {
final UUID duplicateGuid = UUID.randomUUID(); final UUID duplicateGuid = UUID.randomUUID();
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10) assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
.count() .count()
@ -149,7 +149,7 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID, final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS); DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
@ -175,12 +175,12 @@ class MessagesCacheTest {
for (final MessageProtos.Envelope message : messagesToRemove) { for (final MessageProtos.Envelope message : messagesToRemove) {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
message); message).join();
} }
for (final MessageProtos.Envelope message : messagesToPreserve) { for (final MessageProtos.Envelope message : messagesToPreserve) {
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
message); message).join();
} }
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
@ -197,7 +197,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
} }
@ -208,7 +208,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
} }
@ -223,7 +223,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
assertEquals(expectedOldestTimestamp, assertEquals(expectedOldestTimestamp,
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block()); messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
expectedMessages.add(message); expectedMessages.add(message);
@ -248,7 +248,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
expectedMessages.add(message); expectedMessages.add(message);
} }
@ -262,7 +262,7 @@ class MessagesCacheTest {
final UUID message1Guid = UUID.randomUUID(); final UUID message1Guid = UUID.randomUUID();
final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender); final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender);
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1); messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1).join();
final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID, final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID,
1); 1);
assertEquals(List.of(message1), get1); assertEquals(List.of(message1), get1);
@ -272,7 +272,7 @@ class MessagesCacheTest {
final UUID message2Guid = UUID.randomUUID(); final UUID message2Guid = UUID.randomUUID();
final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender); final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender);
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2); messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2).join();
assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1)); assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1));
} }
@ -287,7 +287,7 @@ class MessagesCacheTest {
for (int i = 0; i < messageCount; i++) { for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
expectedMessages.add(message); expectedMessages.add(message);
} }
@ -295,7 +295,7 @@ class MessagesCacheTest {
final UUID ephemeralMessageGuid = UUID.randomUUID(); final UUID ephemeralMessageGuid = UUID.randomUUID();
final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true) final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true)
.toBuilder().setEphemeral(true).build(); .toBuilder().setEphemeral(true).build();
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage); messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage).join();
final Clock cacheClock; final Clock cacheClock;
if (expectStale) { if (expectStale) {
@ -352,7 +352,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
} }
} }
@ -372,7 +372,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message); messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
} }
} }
@ -404,7 +404,7 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender)); generateRandomMessage(messageGuid, sealedSender)).join();
final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID); final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty()); assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
@ -427,7 +427,7 @@ class MessagesCacheTest {
final byte[] sharedMrmDataKey; final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) { if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
} else { } else {
sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8); sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8);
} }
@ -440,7 +440,7 @@ class MessagesCacheTest {
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.clearContent() .clearContent()
.build(); .build();
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message); messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message).join();
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster() assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey))); .withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
@ -487,13 +487,13 @@ class MessagesCacheTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
new AciServiceIdentifier(destinationUuid), true); new AciServiceIdentifier(destinationUuid), true);
messagesCache.insert(messageGuid, destinationUuid, deviceId, message); messagesCache.insert(messageGuid, destinationUuid, deviceId, message).join();
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId); final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
final byte[] sharedMrmDataKey; final byte[] sharedMrmDataKey;
if (sharedMrmKeyPresent) { if (sharedMrmKeyPresent) {
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm); sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
} else { } else {
sharedMrmDataKey = new byte[]{1}; sharedMrmDataKey = new byte[]{1};
} }
@ -505,7 +505,7 @@ class MessagesCacheTest {
.clearContent() .clearContent()
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey)) .setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
.build(); .build();
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage); messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join();
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100); final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);

View File

@ -7,22 +7,42 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.time.Instant; import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
import org.whispersystems.textsecuregcm.util.TestClock;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
class MessagesManagerTest { class MessagesManagerTest {
@ -31,8 +51,15 @@ class MessagesManagerTest {
private final MessagesCache messagesCache = mock(MessagesCache.class); private final MessagesCache messagesCache = mock(MessagesCache.class);
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager, Executors.newSingleThreadExecutor()); reportMessageManager, Executors.newSingleThreadExecutor(), CLOCK);
@BeforeEach
void setUp() {
when(messagesCache.insert(any(), any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(true));
}
@Test @Test
void insert() { void insert() {
@ -43,7 +70,7 @@ class MessagesManagerTest {
final UUID destinationUuid = UUID.randomUUID(); final UUID destinationUuid = UUID.randomUUID();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message); messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, message));
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class)); verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
@ -51,11 +78,113 @@ class MessagesManagerTest {
.setSourceServiceId(destinationUuid.toString()) .setSourceServiceId(destinationUuid.toString())
.build(); .build();
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage); messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, syncMessage));
verifyNoMoreInteractions(reportMessageManager); verifyNoMoreInteractions(reportMessageManager);
} }
@Test
void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException {
final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final ServiceIdentifier unresolvedAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Account singleDeviceAccount = mock(Account.class);
final Account multiDeviceAccount = mock(Account.class);
when(singleDeviceAccount.getIdentifier(IdentityType.ACI))
.thenReturn(singleDeviceAccountAciServiceIdentifier.uuid());
when(multiDeviceAccount.getIdentifier(IdentityType.ACI))
.thenReturn(multiDeviceAccountAciServiceIdentifier.uuid());
final byte[] multiRecipientMessageBytes = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(
new TestRecipient(singleDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]),
new TestRecipient(multiDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 2, new byte[48]),
new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]),
new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]),
new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48])
));
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes);
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = new HashMap<>();
multiRecipientMessage.getRecipients().forEach(((serviceId, recipient) -> {
if (serviceId.getRawUUID().equals(singleDeviceAccountAciServiceIdentifier.uuid()) ||
serviceId.getRawUUID().equals(singleDeviceAccountPniServiceIdentifier.uuid())) {
resolvedRecipients.put(recipient, singleDeviceAccount);
} else if (serviceId.getRawUUID().equals(multiDeviceAccountAciServiceIdentifier.uuid())) {
resolvedRecipients.put(recipient, multiDeviceAccount);
}
}));
final Map<Account, Map<Byte, Boolean>> expectedPresenceByAccountAndDeviceId = Map.of(
singleDeviceAccount, Map.of(Device.PRIMARY_ID, true),
multiDeviceAccount, Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
);
final Map<UUID, Map<Byte, Boolean>> presenceByAccountIdentifierAndDeviceId = Map.of(
singleDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, true),
multiDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
);
final byte[] sharedMrmKey = "shared-mrm-key".getBytes(StandardCharsets.UTF_8);
when(messagesCache.insertSharedMultiRecipientMessagePayload(multiRecipientMessage))
.thenReturn(CompletableFuture.completedFuture(sharedMrmKey));
when(messagesCache.insert(any(), any(), anyByte(), any()))
.thenAnswer(invocation -> {
final UUID accountIdentifier = invocation.getArgument(1);
final byte deviceId = invocation.getArgument(2);
return CompletableFuture.completedFuture(
presenceByAccountIdentifierAndDeviceId.getOrDefault(accountIdentifier, Collections.emptyMap())
.getOrDefault(deviceId, false));
});
final long clientTimestamp = System.currentTimeMillis();
final boolean isStory = ThreadLocalRandom.current().nextBoolean();
final boolean isEphemeral = ThreadLocalRandom.current().nextBoolean();
final boolean isUrgent = ThreadLocalRandom.current().nextBoolean();
final Envelope prototypeExpectedMessage = Envelope.newBuilder()
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
.setClientTimestamp(clientTimestamp)
.setServerTimestamp(CLOCK.millis())
.setStory(isStory)
.setEphemeral(isEphemeral)
.setUrgent(isUrgent)
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
.build();
assertEquals(expectedPresenceByAccountAndDeviceId,
messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join());
verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache).insert(any(),
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
eq(Device.PRIMARY_ID),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache).insert(any(),
eq(multiDeviceAccountAciServiceIdentifier.uuid()),
eq((byte) (Device.PRIMARY_ID + 1)),
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
verify(messagesCache, never()).insert(any(),
eq(unresolvedAccountAciServiceIdentifier.uuid()),
anyByte(),
any());
}
@ParameterizedTest @ParameterizedTest
@CsvSource({ @CsvSource({
"false, false, false", "false, false, false",

View File

@ -29,6 +29,7 @@ class ReportMessageDynamoDbTest {
void setUp() { void setUp() {
this.reportMessageDynamoDb = new ReportMessageDynamoDb( this.reportMessageDynamoDb = new ReportMessageDynamoDb(
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.REPORT_MESSAGES.tableName(), Tables.REPORT_MESSAGES.tableName(),
Duration.ofDays(1)); Duration.ofDays(1));
} }
@ -44,8 +45,8 @@ class ReportMessageDynamoDbTest {
() -> assertFalse(reportMessageDynamoDb.remove(hash2)) () -> assertFalse(reportMessageDynamoDb.remove(hash2))
); );
reportMessageDynamoDb.store(hash1); reportMessageDynamoDb.store(hash1).join();
reportMessageDynamoDb.store(hash2); reportMessageDynamoDb.store(hash2).join();
assertAll("both hashes should be found", assertAll("both hashes should be found",
() -> assertTrue(reportMessageDynamoDb.remove(hash1)), () -> assertTrue(reportMessageDynamoDb.remove(hash1)),

View File

@ -18,6 +18,7 @@ import static org.mockito.Mockito.when;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
@ -68,8 +69,8 @@ class ReportMessageManagerTest {
verify(reportMessageDynamoDb).store(any()); verify(reportMessageDynamoDb).store(any());
doThrow(RuntimeException.class) when(reportMessageDynamoDb.store(any()))
.when(reportMessageDynamoDb).store(any()); .thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid)); assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid));
} }

View File

@ -0,0 +1,92 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.nio.ByteBuffer;
import java.util.List;
public class MultiRecipientMessageHelper {
private MultiRecipientMessageHelper() {
}
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients) {
return generateMultiRecipientMessage(recipients, 32);
}
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients, final int sharedPayloadSize) {
if (sharedPayloadSize < 32) {
throw new IllegalArgumentException("Shared payload size must be at least 32 bytes");
}
final ByteBuffer buffer = ByteBuffer.allocate(payloadSize(recipients, sharedPayloadSize));
// first write the header
buffer.put((byte) 0x23); // version byte
// count varint
writeVarint(buffer, recipients.size());
recipients.forEach(recipient -> {
buffer.put(recipient.uuid().toFixedWidthByteArray());
assert recipient.deviceIds().length == recipient.registrationIds().length;
for (int i = 0; i < recipient.deviceIds().length; i++) {
final int hasMore = i == recipient.deviceIds().length - 1 ? 0 : 0x8000;
buffer.put(recipient.deviceIds()[i]); // device id (1 byte)
buffer.putShort((short) (recipient.registrationIds()[i] | hasMore)); // registration id (2 bytes)
}
buffer.put(recipient.perRecipientKeyMaterial()); // key material (48 bytes)
});
// now write the actual message body (empty for now)
writeVarint(buffer, sharedPayloadSize);
buffer.put(new byte[sharedPayloadSize]);
return buffer.array();
}
private static void writeVarint(final ByteBuffer buffer, long n) {
if (n < 0) {
throw new IllegalArgumentException();
}
while (n >= 0x80) {
buffer.put ((byte) (n & 0x7F | 0x80));
n >>= 7;
}
buffer.put((byte) (n & 0x7F));
}
private static int payloadSize(final List<TestRecipient> recipients, final int sharedPayloadSize) {
final int fixedBytesPerRecipient = 17 // Service identifier length
+ 48; // Per-recipient key material
final int bytesForDevices = 3 * recipients.stream()
.mapToInt(recipient -> recipient.deviceIds().length)
.sum();
return 1 // Version byte
+ varintLength(recipients.size())
+ (recipients.size() * fixedBytesPerRecipient)
+ bytesForDevices
+ varintLength(sharedPayloadSize)
+ sharedPayloadSize;
}
private static int varintLength(long n) {
int length = 0;
while (n >= 0x80) {
length += 1;
n >>= 7;
}
return length + 1;
}
}

View File

@ -0,0 +1,22 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public record TestRecipient(ServiceIdentifier uuid,
byte[] deviceIds,
int[] registrationIds,
byte[] perRecipientKeyMaterial) {
public TestRecipient(ServiceIdentifier uuid,
byte deviceId,
int registrationId,
byte[] perRecipientKeyMaterial) {
this(uuid, new byte[]{deviceId}, new int[]{registrationId}, perRecipientKeyMaterial);
}
}

View File

@ -132,7 +132,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) { void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class), mock(PushNotificationScheduler.class),
@ -164,7 +164,7 @@ class WebSocketConnectionIntegrationTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope); expectedMessages.add(envelope);
} }
@ -220,7 +220,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessagesClientClosed() { void testProcessStoredMessagesClientClosed() {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class), mock(PushNotificationScheduler.class),
@ -253,7 +253,7 @@ class WebSocketConnectionIntegrationTest {
for (int i = 0; i < cachedMessageCount; i++) { for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope); expectedMessages.add(envelope);
} }
@ -289,7 +289,7 @@ class WebSocketConnectionIntegrationTest {
void testProcessStoredMessagesSendFutureTimeout() { void testProcessStoredMessagesSendFutureTimeout() {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
new MessageMetrics(), new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class), mock(PushNotificationScheduler.class),
@ -323,7 +323,7 @@ class WebSocketConnectionIntegrationTest {
for (int i = 0; i < cachedMessageCount; i++) { for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
expectedMessages.add(envelope); expectedMessages.add(envelope);
} }