Clarify roles/responsibilities of components in the message-handling pathway
This commit is contained in:
parent
282bcf6f34
commit
48ada8e8ca
|
@ -431,7 +431,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
config.getDynamoDbTables().getRemoteConfig().getTableName());
|
config.getDynamoDbTables().getRemoteConfig().getTableName());
|
||||||
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient,
|
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbClient,
|
||||||
config.getDynamoDbTables().getPushChallenge().getTableName());
|
config.getDynamoDbTables().getPushChallenge().getTableName());
|
||||||
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
|
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
|
||||||
config.getDynamoDbTables().getReportMessage().getTableName(),
|
config.getDynamoDbTables().getReportMessage().getTableName(),
|
||||||
config.getReportMessageConfiguration().getReportTtl());
|
config.getReportMessageConfiguration().getReportTtl());
|
||||||
RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords(
|
RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords(
|
||||||
|
@ -618,7 +618,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
|
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
|
||||||
config.getReportMessageConfiguration().getCounterTtl());
|
config.getReportMessageConfiguration().getCounterTtl());
|
||||||
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager,
|
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager,
|
||||||
messageDeletionAsyncExecutor);
|
messageDeletionAsyncExecutor, Clock.systemUTC());
|
||||||
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
|
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
|
||||||
config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
|
config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
|
||||||
ClientPublicKeysManager clientPublicKeysManager =
|
ClientPublicKeysManager clientPublicKeysManager =
|
||||||
|
@ -1128,7 +1128,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
new KeyTransparencyController(keyTransparencyServiceClient),
|
new KeyTransparencyController(keyTransparencyServiceClient),
|
||||||
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
|
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
|
||||||
accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler,
|
accountsManager, messagesManager, phoneNumberIdentifiers, pushNotificationManager, pushNotificationScheduler,
|
||||||
reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, clientReleaseManager,
|
reportMessageManager, messageDeliveryScheduler, clientReleaseManager,
|
||||||
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor,
|
dynamicConfigurationManager, zkSecretParams, spamChecker, messageMetrics, messageDeliveryLoopMonitor,
|
||||||
Clock.systemUTC()),
|
Clock.systemUTC()),
|
||||||
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
|
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
|
||||||
|
|
|
@ -7,6 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
|
||||||
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import java.security.MessageDigest;
|
import java.security.MessageDigest;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
public class UnidentifiedAccessUtil {
|
public class UnidentifiedAccessUtil {
|
||||||
|
|
||||||
|
@ -31,4 +34,42 @@ public class UnidentifiedAccessUtil {
|
||||||
.map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey))
|
.map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey))
|
||||||
.orElse(false);
|
.orElse(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks whether an action (e.g. sending a message or retrieving pre-keys) may be taken on the collection of target
|
||||||
|
* accounts by an actor presenting the given combined unidentified access key.
|
||||||
|
*
|
||||||
|
* @param targetAccounts the accounts on which an actor wishes to take an action
|
||||||
|
* @param combinedUnidentifiedAccessKey the unidentified access key presented by the actor
|
||||||
|
*
|
||||||
|
* @return {@code true} if an actor presenting the given unidentified access key has permission to take an action on
|
||||||
|
* the target accounts or {@code false} otherwise
|
||||||
|
*/
|
||||||
|
public static boolean checkUnidentifiedAccess(final Collection<Account> targetAccounts, final byte[] combinedUnidentifiedAccessKey) {
|
||||||
|
return MessageDigest.isEqual(getCombinedUnidentifiedAccessKey(targetAccounts), combinedUnidentifiedAccessKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates a combined unidentified access key for the given collection of accounts.
|
||||||
|
*
|
||||||
|
* @param accounts the accounts from which to derive a combined unidentified access key
|
||||||
|
* @return a combined unidentified access key
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if one or more of the given accounts had an unidentified access key with an
|
||||||
|
* unexpected length
|
||||||
|
*/
|
||||||
|
public static byte[] getCombinedUnidentifiedAccessKey(final Collection<Account> accounts) {
|
||||||
|
return accounts.stream()
|
||||||
|
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
|
||||||
|
.map(account ->
|
||||||
|
account.getUnidentifiedAccessKey()
|
||||||
|
.filter(b -> b.length == UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH)
|
||||||
|
.orElseThrow(IllegalArgumentException::new))
|
||||||
|
.reduce(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH],
|
||||||
|
(a, b) -> {
|
||||||
|
final byte[] xor = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
|
||||||
|
IntStream.range(0, UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH).forEach(i -> xor[i] = (byte) (a[i] ^ b[i]));
|
||||||
|
return xor;
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,10 +9,8 @@ import static com.codahale.metrics.MetricRegistry.name;
|
||||||
import com.codahale.metrics.annotation.Timed;
|
import com.codahale.metrics.annotation.Timed;
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.net.HttpHeaders;
|
import com.google.common.net.HttpHeaders;
|
||||||
import com.google.protobuf.ByteString;
|
|
||||||
import io.dropwizard.auth.Auth;
|
import io.dropwizard.auth.Auth;
|
||||||
import io.dropwizard.util.DataSize;
|
import io.dropwizard.util.DataSize;
|
||||||
import io.micrometer.core.instrument.Counter;
|
|
||||||
import io.micrometer.core.instrument.DistributionSummary;
|
import io.micrometer.core.instrument.DistributionSummary;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.Tag;
|
import io.micrometer.core.instrument.Tag;
|
||||||
|
@ -47,33 +45,26 @@ import jakarta.ws.rs.core.Context;
|
||||||
import jakarta.ws.rs.core.MediaType;
|
import jakarta.ws.rs.core.MediaType;
|
||||||
import jakarta.ws.rs.core.Response;
|
import jakarta.ws.rs.core.Response;
|
||||||
import jakarta.ws.rs.core.Response.Status;
|
import jakarta.ws.rs.core.Response.Status;
|
||||||
import java.security.MessageDigest;
|
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Base64;
|
import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CancellationException;
|
import java.util.concurrent.CancellationException;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.function.Predicate;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
import javax.annotation.Nonnull;
|
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.glassfish.jersey.server.ManagedAsync;
|
import org.glassfish.jersey.server.ManagedAsync;
|
||||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
|
|
||||||
import org.signal.libsignal.protocol.ServiceId;
|
import org.signal.libsignal.protocol.ServiceId;
|
||||||
import org.signal.libsignal.protocol.util.Pair;
|
import org.signal.libsignal.protocol.util.Pair;
|
||||||
import org.signal.libsignal.zkgroup.ServerSecretParams;
|
import org.signal.libsignal.zkgroup.ServerSecretParams;
|
||||||
|
@ -135,6 +126,7 @@ import org.whispersystems.websocket.auth.ReadOnly;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.core.scheduler.Scheduler;
|
import reactor.core.scheduler.Scheduler;
|
||||||
|
import reactor.util.function.Tuple2;
|
||||||
import reactor.util.function.Tuples;
|
import reactor.util.function.Tuples;
|
||||||
|
|
||||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||||
|
@ -142,14 +134,6 @@ import reactor.util.function.Tuples;
|
||||||
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
|
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
|
||||||
public class MessageController {
|
public class MessageController {
|
||||||
|
|
||||||
|
|
||||||
private record MultiRecipientDeliveryData(
|
|
||||||
ServiceIdentifier serviceIdentifier,
|
|
||||||
Account account,
|
|
||||||
Recipient recipient,
|
|
||||||
Map<Byte, Short> deviceIdToRegistrationId) {
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
|
||||||
|
|
||||||
private final RateLimiters rateLimiters;
|
private final RateLimiters rateLimiters;
|
||||||
|
@ -162,7 +146,6 @@ public class MessageController {
|
||||||
private final PushNotificationManager pushNotificationManager;
|
private final PushNotificationManager pushNotificationManager;
|
||||||
private final PushNotificationScheduler pushNotificationScheduler;
|
private final PushNotificationScheduler pushNotificationScheduler;
|
||||||
private final ReportMessageManager reportMessageManager;
|
private final ReportMessageManager reportMessageManager;
|
||||||
private final ExecutorService multiRecipientMessageExecutor;
|
|
||||||
private final Scheduler messageDeliveryScheduler;
|
private final Scheduler messageDeliveryScheduler;
|
||||||
private final ClientReleaseManager clientReleaseManager;
|
private final ClientReleaseManager clientReleaseManager;
|
||||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||||
|
@ -229,7 +212,6 @@ public class MessageController {
|
||||||
PushNotificationManager pushNotificationManager,
|
PushNotificationManager pushNotificationManager,
|
||||||
PushNotificationScheduler pushNotificationScheduler,
|
PushNotificationScheduler pushNotificationScheduler,
|
||||||
ReportMessageManager reportMessageManager,
|
ReportMessageManager reportMessageManager,
|
||||||
@Nonnull ExecutorService multiRecipientMessageExecutor,
|
|
||||||
Scheduler messageDeliveryScheduler,
|
Scheduler messageDeliveryScheduler,
|
||||||
final ClientReleaseManager clientReleaseManager,
|
final ClientReleaseManager clientReleaseManager,
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
|
@ -248,7 +230,6 @@ public class MessageController {
|
||||||
this.pushNotificationManager = pushNotificationManager;
|
this.pushNotificationManager = pushNotificationManager;
|
||||||
this.pushNotificationScheduler = pushNotificationScheduler;
|
this.pushNotificationScheduler = pushNotificationScheduler;
|
||||||
this.reportMessageManager = reportMessageManager;
|
this.reportMessageManager = reportMessageManager;
|
||||||
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
|
|
||||||
this.messageDeliveryScheduler = messageDeliveryScheduler;
|
this.messageDeliveryScheduler = messageDeliveryScheduler;
|
||||||
this.clientReleaseManager = clientReleaseManager;
|
this.clientReleaseManager = clientReleaseManager;
|
||||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||||
|
@ -332,15 +313,15 @@ public class MessageController {
|
||||||
throw new WebApplicationException(Status.FORBIDDEN);
|
throw new WebApplicationException(Status.FORBIDDEN);
|
||||||
}
|
}
|
||||||
|
|
||||||
final Optional<Account> destination;
|
final Optional<Account> maybeDestination;
|
||||||
if (!isSyncMessage) {
|
if (!isSyncMessage) {
|
||||||
destination = accountsManager.getByServiceIdentifier(destinationIdentifier);
|
maybeDestination = accountsManager.getByServiceIdentifier(destinationIdentifier);
|
||||||
} else {
|
} else {
|
||||||
destination = source.map(AuthenticatedDevice::getAccount);
|
maybeDestination = source.map(AuthenticatedDevice::getAccount);
|
||||||
}
|
}
|
||||||
|
|
||||||
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(
|
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(
|
||||||
context, source, destination, Optional.of(destinationIdentifier));
|
context, source, maybeDestination, Optional.of(destinationIdentifier));
|
||||||
final Optional<byte[]> reportSpamToken;
|
final Optional<byte[]> reportSpamToken;
|
||||||
switch (spamCheck) {
|
switch (spamCheck) {
|
||||||
case final SpamChecker.Spam spam: return spam.response();
|
case final SpamChecker.Spam spam: return spam.response();
|
||||||
|
@ -376,11 +357,11 @@ public class MessageController {
|
||||||
// Stories will be checked by the client; we bypass access checks here for stories.
|
// Stories will be checked by the client; we bypass access checks here for stories.
|
||||||
} else if (groupSendToken != null) {
|
} else if (groupSendToken != null) {
|
||||||
checkGroupSendToken(List.of(destinationIdentifier.toLibsignal()), groupSendToken);
|
checkGroupSendToken(List.of(destinationIdentifier.toLibsignal()), groupSendToken);
|
||||||
if (destination.isEmpty()) {
|
if (maybeDestination.isEmpty()) {
|
||||||
throw new NotFoundException();
|
throw new NotFoundException();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, destination,
|
OptionalAccess.verify(source.map(AuthenticatedDevice::getAccount), accessKey, maybeDestination,
|
||||||
destinationIdentifier);
|
destinationIdentifier);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,20 +370,20 @@ public class MessageController {
|
||||||
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
|
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
|
||||||
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
|
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
|
||||||
// these requests.
|
// these requests.
|
||||||
if (isStory && destination.isEmpty()) {
|
if (isStory && maybeDestination.isEmpty()) {
|
||||||
return Response.ok(new SendMessageResponse(needsSync)).build();
|
return Response.ok(new SendMessageResponse(needsSync)).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
// if destination is empty we would either throw an exception in OptionalAccess.verify when isStory is false
|
// if destination is empty we would either throw an exception in OptionalAccess.verify when isStory is false
|
||||||
// or else return a 200 response when isStory is true.
|
// or else return a 200 response when isStory is true.
|
||||||
assert destination.isPresent();
|
final Account destination = maybeDestination.orElseThrow();
|
||||||
|
|
||||||
if (source.isPresent() && !isSyncMessage) {
|
if (source.isPresent() && !isSyncMessage) {
|
||||||
checkMessageRateLimit(source.get(), destination.get(), userAgent);
|
checkMessageRateLimit(source.get(), destination, userAgent);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isStory) {
|
if (isStory) {
|
||||||
rateLimiters.getStoriesLimiter().validate(destination.get().getUuid());
|
rateLimiters.getStoriesLimiter().validate(destination.getUuid());
|
||||||
}
|
}
|
||||||
|
|
||||||
final Set<Byte> excludedDeviceIds;
|
final Set<Byte> excludedDeviceIds;
|
||||||
|
@ -413,15 +394,32 @@ public class MessageController {
|
||||||
excludedDeviceIds = Collections.emptySet();
|
excludedDeviceIds = Collections.emptySet();
|
||||||
}
|
}
|
||||||
|
|
||||||
DestinationDeviceValidator.validateCompleteDeviceList(destination.get(),
|
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
|
||||||
messages.messages().stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()),
|
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
|
||||||
|
try {
|
||||||
|
return message.toEnvelope(
|
||||||
|
destinationIdentifier,
|
||||||
|
source.map(AuthenticatedDevice::getAccount).orElse(null),
|
||||||
|
source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null),
|
||||||
|
messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(),
|
||||||
|
isStory,
|
||||||
|
messages.urgent(),
|
||||||
|
reportSpamToken.orElse(null));
|
||||||
|
} catch (final IllegalArgumentException e) {
|
||||||
|
logger.warn("Received bad envelope type {} from {}", message.type(), userAgent);
|
||||||
|
throw new BadRequestException(e);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
DestinationDeviceValidator.validateCompleteDeviceList(destination,
|
||||||
|
messagesByDeviceId.keySet(),
|
||||||
excludedDeviceIds);
|
excludedDeviceIds);
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(destination.get(),
|
DestinationDeviceValidator.validateRegistrationIds(destination,
|
||||||
messages.messages(),
|
messages.messages(),
|
||||||
IncomingMessage::destinationDeviceId,
|
IncomingMessage::destinationDeviceId,
|
||||||
IncomingMessage::destinationRegistrationId,
|
IncomingMessage::destinationRegistrationId,
|
||||||
destination.get().getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
|
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
|
||||||
|
|
||||||
final String authType;
|
final String authType;
|
||||||
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
|
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
|
||||||
|
@ -434,31 +432,15 @@ public class MessageController {
|
||||||
authType = AUTH_TYPE_ACCESS_KEY;
|
authType = AUTH_TYPE_ACCESS_KEY;
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
messageSender.sendMessages(destination, messagesByDeviceId);
|
||||||
|
|
||||||
|
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||||
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
|
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
|
||||||
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())),
|
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())),
|
||||||
Tag.of(SENDER_TYPE_TAG_NAME, senderType),
|
Tag.of(SENDER_TYPE_TAG_NAME, senderType),
|
||||||
Tag.of(AUTH_TYPE_TAG_NAME, authType),
|
Tag.of(AUTH_TYPE_TAG_NAME, authType),
|
||||||
Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name()));
|
Tag.of(IDENTITY_TYPE_TAG_NAME, destinationIdentifier.identityType().name())))
|
||||||
|
.increment(messagesByDeviceId.size());
|
||||||
for (final IncomingMessage incomingMessage : messages.messages()) {
|
|
||||||
destination.get().getDevice(incomingMessage.destinationDeviceId())
|
|
||||||
.ifPresent(destinationDevice -> {
|
|
||||||
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
|
|
||||||
sendIndividualMessage(
|
|
||||||
source,
|
|
||||||
destination.get(),
|
|
||||||
destinationDevice,
|
|
||||||
destinationIdentifier,
|
|
||||||
messages.timestamp(),
|
|
||||||
messages.online(),
|
|
||||||
isStory,
|
|
||||||
messages.urgent(),
|
|
||||||
incomingMessage,
|
|
||||||
userAgent,
|
|
||||||
reportSpamToken);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return Response.ok(new SendMessageResponse(needsSync)).build();
|
return Response.ok(new SendMessageResponse(needsSync)).build();
|
||||||
} catch (final MismatchedDevicesException e) {
|
} catch (final MismatchedDevicesException e) {
|
||||||
|
@ -481,34 +463,6 @@ public class MessageController {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build mapping of service IDs to resolved accounts and device/registration IDs
|
|
||||||
*/
|
|
||||||
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
|
|
||||||
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
|
|
||||||
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
|
|
||||||
.switchIfEmpty(Flux.error(BadRequestException::new))
|
|
||||||
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
|
|
||||||
.flatMap(
|
|
||||||
t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1()))
|
|
||||||
.flatMap(Mono::justOrEmpty)
|
|
||||||
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
|
|
||||||
.map(
|
|
||||||
account ->
|
|
||||||
new MultiRecipientDeliveryData(
|
|
||||||
t.getT1(),
|
|
||||||
account,
|
|
||||||
t.getT2(),
|
|
||||||
t.getT2().getDevicesAndRegistrationIds().collect(
|
|
||||||
Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second))))
|
|
||||||
// IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device
|
|
||||||
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e),
|
|
||||||
MAX_FETCH_ACCOUNT_CONCURRENCY)
|
|
||||||
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
|
|
||||||
.block();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Timed
|
@Timed
|
||||||
@Path("/multi_recipient")
|
@Path("/multi_recipient")
|
||||||
@PUT
|
@PUT
|
||||||
|
@ -565,6 +519,32 @@ public class MessageController {
|
||||||
throw new BadRequestException("Illegal timestamp");
|
throw new BadRequestException("Illegal timestamp");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (multiRecipientMessage.getRecipients().isEmpty()) {
|
||||||
|
throw new BadRequestException("Recipient list is empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the message isn't too large before performing more expensive validations
|
||||||
|
multiRecipientMessage.getRecipients().values().forEach(recipient ->
|
||||||
|
validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipient), true, userAgent));
|
||||||
|
|
||||||
|
// Check that the request is well-formed and doesn't contain repeated entries for the same device for the same
|
||||||
|
// recipient
|
||||||
|
{
|
||||||
|
final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID];
|
||||||
|
|
||||||
|
for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) {
|
||||||
|
Arrays.fill(usedDeviceIds, false);
|
||||||
|
|
||||||
|
for (final byte deviceId : recipient.getDevices()) {
|
||||||
|
if (usedDeviceIds[deviceId]) {
|
||||||
|
throw new BadRequestException();
|
||||||
|
}
|
||||||
|
|
||||||
|
usedDeviceIds[deviceId] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty(), Optional.empty());
|
final SpamChecker.SpamCheckResult spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty(), Optional.empty());
|
||||||
if (spamCheck instanceof final SpamChecker.Spam spam) {
|
if (spamCheck instanceof final SpamChecker.Spam spam) {
|
||||||
return spam.response();
|
return spam.response();
|
||||||
|
@ -584,28 +564,43 @@ public class MessageController {
|
||||||
if (groupSendToken != null) {
|
if (groupSendToken != null) {
|
||||||
// Group send endorsements are checked before we even attempt to resolve any accounts, since
|
// Group send endorsements are checked before we even attempt to resolve any accounts, since
|
||||||
// the lists of service IDs in the envelope are all that we need to check against
|
// the lists of service IDs in the envelope are all that we need to check against
|
||||||
checkGroupSendToken(
|
checkGroupSendToken(multiRecipientMessage.getRecipients().keySet(), groupSendToken);
|
||||||
multiRecipientMessage.getRecipients().keySet(), groupSendToken);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
|
// At this point, the caller has at least superficially provided the information needed to send a multi-recipient
|
||||||
|
// message. Attempt to resolve the destination service identifiers to Signal accounts.
|
||||||
|
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients =
|
||||||
|
Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
|
||||||
|
.flatMap(serviceIdAndRecipient -> {
|
||||||
|
final ServiceIdentifier serviceIdentifier =
|
||||||
|
ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
|
||||||
|
|
||||||
|
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
|
||||||
|
.flatMap(Mono::justOrEmpty)
|
||||||
|
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
|
||||||
|
.map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account));
|
||||||
|
}, MAX_FETCH_ACCOUNT_CONCURRENCY)
|
||||||
|
.collectMap(Tuple2::getT1, Tuple2::getT2)
|
||||||
|
.blockOptional()
|
||||||
|
.orElse(Collections.emptyMap());
|
||||||
|
|
||||||
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
|
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
|
||||||
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
|
// Group send endorsements are checked earlier; for stories, we don't check permissions at all because only clients check them
|
||||||
if (groupSendToken == null && !isStory) {
|
if (groupSendToken == null && !isStory) {
|
||||||
checkAccessKeys(accessKeys, recipients.values());
|
checkAccessKeys(accessKeys, multiRecipientMessage, resolvedRecipients);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We might filter out all the recipients of a story (if none exist).
|
// We might filter out all the recipients of a story (if none exist).
|
||||||
// In this case there is no error so we should just return 200 now.
|
// In this case there is no error so we should just return 200 now.
|
||||||
if (isStory) {
|
if (isStory) {
|
||||||
if (recipients.isEmpty()) {
|
if (resolvedRecipients.isEmpty()) {
|
||||||
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
|
return Response.ok(new SendMultiRecipientMessageResponse(List.of())).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
CompletableFuture.allOf(recipients.values()
|
CompletableFuture.allOf(resolvedRecipients.values()
|
||||||
.stream()
|
.stream()
|
||||||
.map(recipient -> recipient.account().getUuid())
|
.map(account -> account.getIdentifier(IdentityType.ACI))
|
||||||
.map(accountIdentifier ->
|
.map(accountIdentifier ->
|
||||||
rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture())
|
rateLimiters.getStoriesLimiter().validateAsync(accountIdentifier).toCompletableFuture())
|
||||||
.toList()
|
.toList()
|
||||||
|
@ -620,31 +615,42 @@ public class MessageController {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
||||||
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
||||||
recipients.values().forEach(recipient -> {
|
|
||||||
final Account account = recipient.account();
|
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
|
||||||
|
if (!resolvedRecipients.containsKey(recipient)) {
|
||||||
|
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
|
||||||
|
// can just skip them.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
final Account account = resolvedRecipients.get(recipient);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(),
|
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
|
||||||
|
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
|
||||||
|
|
||||||
|
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
|
||||||
Collections.emptySet());
|
Collections.emptySet());
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(
|
DestinationDeviceValidator.validateRegistrationIds(
|
||||||
account,
|
account,
|
||||||
recipient.deviceIdToRegistrationId().entrySet(),
|
deviceIdsToRegistrationIds.entrySet(),
|
||||||
Map.Entry<Byte, Short>::getKey,
|
Map.Entry<Byte, Short>::getKey,
|
||||||
e -> Integer.valueOf(e.getValue()),
|
e -> Integer.valueOf(e.getValue()),
|
||||||
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
|
serviceId instanceof ServiceId.Pni);
|
||||||
} catch (MismatchedDevicesException e) {
|
} catch (final MismatchedDevicesException e) {
|
||||||
accountMismatchedDevices.add(
|
accountMismatchedDevices.add(
|
||||||
new AccountMismatchedDevices(
|
new AccountMismatchedDevices(
|
||||||
recipient.serviceIdentifier(),
|
ServiceIdentifier.fromLibsignal(serviceId),
|
||||||
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
||||||
} catch (StaleDevicesException e) {
|
} catch (final StaleDevicesException e) {
|
||||||
accountStaleDevices.add(
|
accountStaleDevices.add(
|
||||||
new AccountStaleDevices(recipient.serviceIdentifier(), new StaleDevices(e.getStaleDevices())));
|
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!accountMismatchedDevices.isEmpty()) {
|
if (!accountMismatchedDevices.isEmpty()) {
|
||||||
return Response
|
return Response
|
||||||
.status(409)
|
.status(409)
|
||||||
|
@ -670,39 +676,30 @@ public class MessageController {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
final byte[] sharedMrmKey = messagesManager.insertSharedMultiRecipientMessagePayload(multiRecipientMessage);
|
messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, timestamp, isStory, online, isUrgent).get();
|
||||||
|
|
||||||
CompletableFuture.allOf(
|
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
|
||||||
recipients.values().stream()
|
if (!resolvedRecipients.containsKey(recipient)) {
|
||||||
.flatMap(recipientData -> {
|
// We skipped sending to this recipient because we're sending a story and couldn't resolve the recipient to
|
||||||
final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of(
|
// an existing account; don't increment the counter for this recipient.
|
||||||
UserAgentTagUtil.getPlatformTag(userAgent),
|
return;
|
||||||
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI),
|
}
|
||||||
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
|
|
||||||
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED),
|
|
||||||
Tag.of(AUTH_TYPE_TAG_NAME, authType),
|
|
||||||
Tag.of(IDENTITY_TYPE_TAG_NAME, recipientData.serviceIdentifier().identityType().name())));
|
|
||||||
|
|
||||||
validateContentLength(multiRecipientMessage.messageSizeForRecipient(recipientData.recipient()), true, userAgent);
|
final String identityType = switch (serviceId) {
|
||||||
|
case ServiceId.Aci ignored -> "ACI";
|
||||||
|
case ServiceId.Pni ignored -> "PNI";
|
||||||
|
default -> "unknown";
|
||||||
|
};
|
||||||
|
|
||||||
return recipientData.deviceIdToRegistrationId().keySet().stream().map(
|
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of(
|
||||||
deviceId -> CompletableFuture.runAsync(
|
UserAgentTagUtil.getPlatformTag(userAgent),
|
||||||
() -> {
|
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_MULTI),
|
||||||
final Account destinationAccount = recipientData.account();
|
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
|
||||||
final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient());
|
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED),
|
||||||
|
Tag.of(AUTH_TYPE_TAG_NAME, authType),
|
||||||
// we asserted this must exist in validateCompleteDeviceList
|
Tag.of(IDENTITY_TYPE_TAG_NAME, identityType)))
|
||||||
final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow();
|
.increment(recipient.getDevices().length);
|
||||||
|
});
|
||||||
sentMessageCounter.increment();
|
|
||||||
sendCommonPayloadMessage(
|
|
||||||
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp,
|
|
||||||
online, isStory, isUrgent, payload, sharedMrmKey);
|
|
||||||
},
|
|
||||||
multiRecipientMessageExecutor));
|
|
||||||
})
|
|
||||||
.toArray(CompletableFuture[]::new))
|
|
||||||
.get();
|
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
logger.error("interrupted while delivering multi-recipient messages", e);
|
logger.error("interrupted while delivering multi-recipient messages", e);
|
||||||
throw new InternalServerErrorException("interrupted during delivery");
|
throw new InternalServerErrorException("interrupted during delivery");
|
||||||
|
@ -729,29 +726,21 @@ public class MessageController {
|
||||||
|
|
||||||
private void checkAccessKeys(
|
private void checkAccessKeys(
|
||||||
final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys,
|
final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys,
|
||||||
final Collection<MultiRecipientDeliveryData> destinations) {
|
final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||||
final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH;
|
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients) {
|
||||||
|
|
||||||
|
if (multiRecipientMessage.getRecipients().keySet().stream()
|
||||||
|
.anyMatch(serviceId -> serviceId instanceof ServiceId.Pni)) {
|
||||||
|
|
||||||
if (destinations.stream()
|
|
||||||
.anyMatch(destination -> IdentityType.PNI.equals(destination.serviceIdentifier.identityType()))) {
|
|
||||||
throw new WebApplicationException("Multi-recipient messages must be addressed to ACI service IDs",
|
throw new WebApplicationException("Multi-recipient messages must be addressed to ACI service IDs",
|
||||||
Status.UNAUTHORIZED);
|
Status.UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
|
|
||||||
final byte[] combinedUnidentifiedAccessKeys = destinations.stream()
|
try {
|
||||||
.map(MultiRecipientDeliveryData::account)
|
if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(resolvedRecipients.values(), accessKeys.getAccessKeys())) {
|
||||||
.filter(Predicate.not(Account::isUnrestrictedUnidentifiedAccess))
|
throw new WebApplicationException(Status.UNAUTHORIZED);
|
||||||
.map(account ->
|
}
|
||||||
account.getUnidentifiedAccessKey()
|
} catch (final IllegalArgumentException ignored) {
|
||||||
.filter(b -> b.length == keyLength)
|
|
||||||
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)))
|
|
||||||
.reduce(new byte[keyLength],
|
|
||||||
(a, b) -> {
|
|
||||||
final byte[] xor = new byte[keyLength];
|
|
||||||
IntStream.range(0, keyLength).forEach(i -> xor[i] = (byte) (a[i] ^ b[i]));
|
|
||||||
return xor;
|
|
||||||
});
|
|
||||||
if (!MessageDigest.isEqual(combinedUnidentifiedAccessKeys, accessKeys.getAccessKeys())) {
|
|
||||||
throw new WebApplicationException(Status.UNAUTHORIZED);
|
throw new WebApplicationException(Status.UNAUTHORIZED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -912,65 +901,6 @@ public class MessageController {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void sendIndividualMessage(
|
|
||||||
Optional<AuthenticatedDevice> source,
|
|
||||||
Account destinationAccount,
|
|
||||||
Device destinationDevice,
|
|
||||||
ServiceIdentifier destinationIdentifier,
|
|
||||||
long timestamp,
|
|
||||||
boolean online,
|
|
||||||
boolean story,
|
|
||||||
boolean urgent,
|
|
||||||
IncomingMessage incomingMessage,
|
|
||||||
String userAgentString,
|
|
||||||
Optional<byte[]> spamReportToken) {
|
|
||||||
|
|
||||||
final Envelope envelope;
|
|
||||||
|
|
||||||
try {
|
|
||||||
final Account sourceAccount = source.map(AuthenticatedDevice::getAccount).orElse(null);
|
|
||||||
final Byte sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
|
|
||||||
envelope = incomingMessage.toEnvelope(
|
|
||||||
destinationIdentifier,
|
|
||||||
sourceAccount,
|
|
||||||
sourceDeviceId,
|
|
||||||
timestamp == 0 ? System.currentTimeMillis() : timestamp,
|
|
||||||
story,
|
|
||||||
urgent,
|
|
||||||
spamReportToken.orElse(null));
|
|
||||||
} catch (final IllegalArgumentException e) {
|
|
||||||
logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString);
|
|
||||||
throw new BadRequestException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
messageSender.sendMessage(destinationAccount, destinationDevice, envelope, online);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void sendCommonPayloadMessage(Account destinationAccount,
|
|
||||||
Device destinationDevice,
|
|
||||||
ServiceIdentifier serviceIdentifier,
|
|
||||||
long timestamp,
|
|
||||||
boolean online,
|
|
||||||
boolean story,
|
|
||||||
boolean urgent,
|
|
||||||
byte[] payload,
|
|
||||||
byte[] sharedMrmKey) {
|
|
||||||
|
|
||||||
final Envelope.Builder messageBuilder = Envelope.newBuilder();
|
|
||||||
final long serverTimestamp = System.currentTimeMillis();
|
|
||||||
|
|
||||||
messageBuilder
|
|
||||||
.setType(Type.UNIDENTIFIED_SENDER)
|
|
||||||
.setClientTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
|
|
||||||
.setServerTimestamp(serverTimestamp)
|
|
||||||
.setStory(story)
|
|
||||||
.setUrgent(urgent)
|
|
||||||
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
|
|
||||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey));
|
|
||||||
|
|
||||||
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkMessageRateLimit(AuthenticatedDevice source, Account destination, String userAgent)
|
private void checkMessageRateLimit(AuthenticatedDevice source, Account destination, String userAgent)
|
||||||
throws RateLimitExceededException {
|
throws RateLimitExceededException {
|
||||||
final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());
|
final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());
|
||||||
|
@ -1020,15 +950,4 @@ public class MessageController {
|
||||||
throw new BadRequestException("reserved envelope type");
|
throw new BadRequestException("reserved envelope type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Optional<byte[]> getMessageContent(IncomingMessage message) {
|
|
||||||
if (StringUtils.isEmpty(message.content())) return Optional.empty();
|
|
||||||
|
|
||||||
try {
|
|
||||||
return Optional.of(Base64.getDecoder().decode(message.content()));
|
|
||||||
} catch (IllegalArgumentException e) {
|
|
||||||
logger.debug("Bad B64", e);
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,11 +55,7 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<SealedSe
|
||||||
|
|
||||||
try {
|
try {
|
||||||
final SealedSenderMultiRecipientMessage message = SealedSenderMultiRecipientMessage.parse(fullMessage);
|
final SealedSenderMultiRecipientMessage message = SealedSenderMultiRecipientMessage.parse(fullMessage);
|
||||||
RECIPIENT_COUNT_DISTRIBUTION.record(message.getRecipients().keySet().size());
|
RECIPIENT_COUNT_DISTRIBUTION.record(message.getRecipients().size());
|
||||||
|
|
||||||
if (message.getRecipients().values().stream().anyMatch(r -> message.messageSizeForRecipient(r) > MAX_MESSAGE_SIZE)) {
|
|
||||||
throw new BadRequestException("message payload too large");
|
|
||||||
}
|
|
||||||
return message;
|
return message;
|
||||||
} catch (InvalidMessageException | InvalidVersionException e) {
|
} catch (InvalidMessageException | InvalidVersionException e) {
|
||||||
throw new BadRequestException(e);
|
throw new BadRequestException(e);
|
||||||
|
|
|
@ -9,9 +9,14 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
|
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
|
||||||
|
@ -42,26 +47,82 @@ public class MessageSender {
|
||||||
this.pushNotificationManager = pushNotificationManager;
|
this.pushNotificationManager = pushNotificationManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean online) {
|
/**
|
||||||
final boolean destinationPresent = messagesManager.insert(account.getUuid(),
|
* Sends messages to devices associated with the given destination account. If a destination device has a valid push
|
||||||
device.getId(),
|
* notification token and does not have an active connection to a Signal server, then this method will also send a
|
||||||
online ? message.toBuilder().setEphemeral(true).build() : message);
|
* push notification to that device to announce the availability of new messages.
|
||||||
|
*
|
||||||
|
* @param account the account to which to send messages
|
||||||
|
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||||
|
*/
|
||||||
|
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
|
||||||
|
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
|
||||||
|
.forEach((deviceId, destinationPresent) -> {
|
||||||
|
final Envelope message = messagesByDeviceId.get(deviceId);
|
||||||
|
|
||||||
if (!destinationPresent && !online) {
|
if (!destinationPresent && !message.getEphemeral()) {
|
||||||
try {
|
try {
|
||||||
pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent());
|
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
|
||||||
} catch (final NotPushRegisteredException ignored) {
|
} catch (final NotPushRegisteredException ignored) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Metrics.counter(SEND_COUNTER_NAME,
|
Metrics.counter(SEND_COUNTER_NAME,
|
||||||
CHANNEL_TAG_NAME, getDeliveryChannelName(device),
|
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
|
||||||
EPHEMERAL_TAG_NAME, String.valueOf(online),
|
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
|
||||||
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
|
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
|
||||||
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
|
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
|
||||||
STORY_TAG_NAME, String.valueOf(message.getStory()),
|
STORY_TAG_NAME, String.valueOf(message.getStory()),
|
||||||
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
|
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
|
||||||
.increment();
|
.increment();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
|
||||||
|
* have an active connection to a Signal server, then this method will also send a push notification to that device to
|
||||||
|
* announce the availability of new messages.
|
||||||
|
*
|
||||||
|
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
|
||||||
|
* @param resolvedRecipients a map of recipients to resolved Signal accounts
|
||||||
|
* @param clientTimestamp the time at which the sender reports the message was sent
|
||||||
|
* @param isStory {@code true} if the message is a story or {@code false otherwise}
|
||||||
|
* @param isEphemeral {@code true} if the message should only be delivered to devices with active connections or
|
||||||
|
* {@code false otherwise}
|
||||||
|
* @param isUrgent {@code true} if the message is urgent or {@code false otherwise}
|
||||||
|
*
|
||||||
|
* @return a future that completes when all messages have been inserted into delivery queues
|
||||||
|
*/
|
||||||
|
public CompletableFuture<Void> sendMultiRecipientMessage(final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||||
|
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients,
|
||||||
|
final long clientTimestamp,
|
||||||
|
final boolean isStory,
|
||||||
|
final boolean isEphemeral,
|
||||||
|
final boolean isUrgent) {
|
||||||
|
|
||||||
|
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
|
||||||
|
isStory, isEphemeral, isUrgent)
|
||||||
|
.thenAccept(clientPresenceByAccountAndDevice ->
|
||||||
|
clientPresenceByAccountAndDevice.forEach((account, clientPresenceByDeviceId) ->
|
||||||
|
clientPresenceByDeviceId.forEach((deviceId, clientPresent) -> {
|
||||||
|
if (!clientPresent && !isEphemeral) {
|
||||||
|
try {
|
||||||
|
pushNotificationManager.sendNewMessageNotification(account, deviceId, isUrgent);
|
||||||
|
} catch (final NotPushRegisteredException ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Metrics.counter(SEND_COUNTER_NAME,
|
||||||
|
CHANNEL_TAG_NAME,
|
||||||
|
account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
|
||||||
|
EPHEMERAL_TAG_NAME, String.valueOf(isEphemeral),
|
||||||
|
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
|
||||||
|
URGENT_TAG_NAME, String.valueOf(isUrgent),
|
||||||
|
STORY_TAG_NAME, String.valueOf(isStory),
|
||||||
|
SEALED_SENDER_TAG_NAME, String.valueOf(true))
|
||||||
|
.increment();
|
||||||
|
})))
|
||||||
|
.thenRun(Util.NOOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
|
|
|
@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
|
@ -43,21 +44,21 @@ public class ReceiptSender {
|
||||||
try {
|
try {
|
||||||
accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse(
|
accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse(
|
||||||
destinationAccount -> {
|
destinationAccount -> {
|
||||||
final Envelope.Builder message = Envelope.newBuilder()
|
final Envelope message = Envelope.newBuilder()
|
||||||
.setServerTimestamp(System.currentTimeMillis())
|
.setServerTimestamp(System.currentTimeMillis())
|
||||||
.setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
|
.setSourceServiceId(sourceIdentifier.toServiceIdentifierString())
|
||||||
.setSourceDevice(sourceDeviceId)
|
.setSourceDevice(sourceDeviceId)
|
||||||
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
|
.setDestinationServiceId(destinationIdentifier.toServiceIdentifierString())
|
||||||
.setClientTimestamp(messageId)
|
.setClientTimestamp(messageId)
|
||||||
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
|
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
|
||||||
.setUrgent(false);
|
.setUrgent(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
for (final Device destinationDevice : destinationAccount.getDevices()) {
|
try {
|
||||||
try {
|
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
|
||||||
messageSender.sendMessage(destinationAccount, destinationDevice, message.build(), false);
|
.collect(Collectors.toMap(Device::getId, ignored -> message)));
|
||||||
} catch (final Exception e) {
|
} catch (final Exception e) {
|
||||||
logger.warn("Could not send delivery receipt", e);
|
logger.warn("Could not send delivery receipt", e);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
() -> logger.info("No longer registered: {}", destinationIdentifier)
|
() -> logger.info("No longer registered: {}", destinationIdentifier)
|
||||||
|
|
|
@ -4,8 +4,8 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.storage;
|
package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
|
||||||
import com.google.protobuf.ByteString;
|
import com.google.protobuf.ByteString;
|
||||||
|
import java.util.Base64;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -13,10 +13,10 @@ import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
import org.apache.commons.lang3.ObjectUtils;
|
import org.apache.commons.lang3.ObjectUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.signal.libsignal.protocol.IdentityKey;
|
import org.signal.libsignal.protocol.IdentityKey;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||||
|
@ -115,40 +115,39 @@ public class ChangeNumberManager {
|
||||||
|
|
||||||
private void sendDeviceMessages(final Account account, final List<IncomingMessage> deviceMessages) {
|
private void sendDeviceMessages(final Account account, final List<IncomingMessage> deviceMessages) {
|
||||||
try {
|
try {
|
||||||
deviceMessages.forEach(message ->
|
final long serverTimestamp = System.currentTimeMillis();
|
||||||
sendMessageToSelf(account, account.getDevice(message.destinationDeviceId()), message));
|
|
||||||
} catch (RuntimeException e) {
|
messageSender.sendMessages(account, deviceMessages.stream()
|
||||||
|
.filter(message -> getMessageContent(message).isPresent())
|
||||||
|
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
|
||||||
|
.setType(Envelope.Type.forNumber(message.type()))
|
||||||
|
.setClientTimestamp(serverTimestamp)
|
||||||
|
.setServerTimestamp(serverTimestamp)
|
||||||
|
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
|
||||||
|
.setContent(ByteString.copyFrom(getMessageContent(message).orElseThrow()))
|
||||||
|
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
|
||||||
|
.setSourceDevice(Device.PRIMARY_ID)
|
||||||
|
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
|
||||||
|
.setUrgent(true)
|
||||||
|
.setEphemeral(false)
|
||||||
|
.build())));
|
||||||
|
} catch (final RuntimeException e) {
|
||||||
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
|
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
|
||||||
throw e;
|
throw e;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
private static Optional<byte[]> getMessageContent(final IncomingMessage message) {
|
||||||
void sendMessageToSelf(
|
if (StringUtils.isEmpty(message.content())) {
|
||||||
Account sourceAndDestinationAccount, Optional<Device> destinationDevice, IncomingMessage message) {
|
logger.warn("Message has no content");
|
||||||
Optional<byte[]> contents = MessageController.getMessageContent(message);
|
return Optional.empty();
|
||||||
if (contents.isEmpty()) {
|
|
||||||
logger.debug("empty message contents sending to self, ignoring");
|
|
||||||
return;
|
|
||||||
} else if (destinationDevice.isEmpty()) {
|
|
||||||
logger.debug("destination device not present");
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final long serverTimestamp = System.currentTimeMillis();
|
try {
|
||||||
final Envelope envelope = Envelope.newBuilder()
|
return Optional.of(Base64.getDecoder().decode(message.content()));
|
||||||
.setType(Envelope.Type.forNumber(message.type()))
|
} catch (final IllegalArgumentException e) {
|
||||||
.setClientTimestamp(serverTimestamp)
|
logger.warn("Failed to parse message content", e);
|
||||||
.setServerTimestamp(serverTimestamp)
|
return Optional.empty();
|
||||||
.setDestinationServiceId(
|
}
|
||||||
new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
|
|
||||||
.setContent(ByteString.copyFrom(contents.get()))
|
|
||||||
.setSourceServiceId(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
|
|
||||||
.setSourceDevice(Device.PRIMARY_ID)
|
|
||||||
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
|
|
||||||
.setUrgent(true)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -203,22 +203,28 @@ public class MessagesCache {
|
||||||
this.unlockQueueScript = unlockQueueScript;
|
this.unlockQueueScript = unlockQueueScript;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean insert(final UUID messageGuid,
|
public CompletableFuture<Boolean> insert(final UUID messageGuid,
|
||||||
final UUID destinationAccountIdentifier,
|
final UUID destinationAccountIdentifier,
|
||||||
final byte destinationDeviceId,
|
final byte destinationDeviceId,
|
||||||
final MessageProtos.Envelope message) {
|
final MessageProtos.Envelope message) {
|
||||||
|
|
||||||
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build();
|
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build();
|
||||||
return insertTimer.record(() -> insertScript.execute(destinationAccountIdentifier, destinationDeviceId, messageWithGuid));
|
final Timer.Sample sample = Timer.start();
|
||||||
|
|
||||||
|
return insertScript.executeAsync(destinationAccountIdentifier, destinationDeviceId, messageWithGuid)
|
||||||
|
.whenComplete((ignored, throwable) -> sample.stop(insertTimer));
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] insertSharedMultiRecipientMessagePayload(
|
public CompletableFuture<byte[]> insertSharedMultiRecipientMessagePayload(
|
||||||
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
|
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
|
||||||
return insertSharedMrmPayloadTimer.record(() -> {
|
|
||||||
final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID());
|
final Timer.Sample sample = Timer.start();
|
||||||
insertMrmScript.execute(sharedMrmKey, sealedSenderMultiRecipientMessage);
|
|
||||||
return sharedMrmKey;
|
final byte[] sharedMrmKey = getSharedMrmKey(UUID.randomUUID());
|
||||||
});
|
|
||||||
|
return insertMrmScript.executeAsync(sharedMrmKey, sealedSenderMultiRecipientMessage)
|
||||||
|
.thenApply(ignored -> sharedMrmKey)
|
||||||
|
.whenComplete((ignored, throwable) -> sample.stop(insertSharedMrmPayloadTimer));
|
||||||
}
|
}
|
||||||
|
|
||||||
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
|
public CompletableFuture<Optional<RemovedMessage>> remove(final UUID destinationUuid, final byte destinationDevice,
|
||||||
|
|
|
@ -12,6 +12,7 @@ import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
import org.whispersystems.textsecuregcm.push.ClientEvent;
|
import org.whispersystems.textsecuregcm.push.ClientEvent;
|
||||||
import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent;
|
import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent;
|
||||||
|
@ -44,7 +45,7 @@ class MessagesCacheInsertScript {
|
||||||
* @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false}
|
* @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false}
|
||||||
* otherwise
|
* otherwise
|
||||||
*/
|
*/
|
||||||
boolean execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
|
CompletableFuture<Boolean> executeAsync(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
|
||||||
assert envelope.hasServerGuid();
|
assert envelope.hasServerGuid();
|
||||||
assert envelope.hasServerTimestamp();
|
assert envelope.hasServerTimestamp();
|
||||||
|
|
||||||
|
@ -62,6 +63,7 @@ class MessagesCacheInsertScript {
|
||||||
NEW_MESSAGE_EVENT_BYTES // eventPayload
|
NEW_MESSAGE_EVENT_BYTES // eventPayload
|
||||||
));
|
));
|
||||||
|
|
||||||
return (boolean) insertScript.executeBinary(keys, args);
|
return insertScript.executeBinaryAsync(keys, args)
|
||||||
|
.thenApply(result -> (boolean) result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,9 +9,11 @@ import io.lettuce.core.ScriptOutputType;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
|
* Inserts the shared multi-recipient message payload into the cache. The list of recipients and views will be set as
|
||||||
|
@ -31,7 +33,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
|
||||||
ScriptOutputType.INTEGER);
|
ScriptOutputType.INTEGER);
|
||||||
}
|
}
|
||||||
|
|
||||||
void execute(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
|
CompletableFuture<Void> executeAsync(final byte[] sharedMrmKey, final SealedSenderMultiRecipientMessage message) {
|
||||||
final List<byte[]> keys = List.of(
|
final List<byte[]> keys = List.of(
|
||||||
sharedMrmKey // sharedMrmKey
|
sharedMrmKey // sharedMrmKey
|
||||||
);
|
);
|
||||||
|
@ -47,6 +49,7 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
script.executeBinary(keys, args);
|
return script.executeBinaryAsync(keys, args)
|
||||||
|
.thenRun(Util.NOOP);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,18 +6,23 @@ package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
import io.micrometer.core.instrument.Counter;
|
import io.micrometer.core.instrument.Counter;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import java.time.Clock;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.concurrent.TimeoutException;
|
import java.util.concurrent.TimeoutException;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
import org.reactivestreams.Publisher;
|
import org.reactivestreams.Publisher;
|
||||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
|
@ -25,6 +30,8 @@ import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import reactor.core.observability.micrometer.Micrometer;
|
import reactor.core.observability.micrometer.Micrometer;
|
||||||
|
@ -48,41 +55,120 @@ public class MessagesManager {
|
||||||
private final MessagesCache messagesCache;
|
private final MessagesCache messagesCache;
|
||||||
private final ReportMessageManager reportMessageManager;
|
private final ReportMessageManager reportMessageManager;
|
||||||
private final ExecutorService messageDeletionExecutor;
|
private final ExecutorService messageDeletionExecutor;
|
||||||
|
private final Clock clock;
|
||||||
|
|
||||||
public MessagesManager(
|
public MessagesManager(
|
||||||
final MessagesDynamoDb messagesDynamoDb,
|
final MessagesDynamoDb messagesDynamoDb,
|
||||||
final MessagesCache messagesCache,
|
final MessagesCache messagesCache,
|
||||||
final ReportMessageManager reportMessageManager,
|
final ReportMessageManager reportMessageManager,
|
||||||
final ExecutorService messageDeletionExecutor) {
|
final ExecutorService messageDeletionExecutor,
|
||||||
|
final Clock clock) {
|
||||||
|
|
||||||
this.messagesDynamoDb = messagesDynamoDb;
|
this.messagesDynamoDb = messagesDynamoDb;
|
||||||
this.messagesCache = messagesCache;
|
this.messagesCache = messagesCache;
|
||||||
this.reportMessageManager = reportMessageManager;
|
this.reportMessageManager = reportMessageManager;
|
||||||
this.messageDeletionExecutor = messageDeletionExecutor;
|
this.messageDeletionExecutor = messageDeletionExecutor;
|
||||||
|
this.clock = clock;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inserts a message into a target device's message queue and notifies registered listeners that a new message is
|
* Inserts messages into the message queues for devices associated with the identified account.
|
||||||
* available.
|
|
||||||
*
|
*
|
||||||
* @param destinationUuid the account identifier for the destination queue
|
* @param accountIdentifier the account identifier for the destination queue
|
||||||
* @param destinationDeviceId the device ID for the destination queue
|
* @param messagesByDeviceId a map of device IDs to messages
|
||||||
* @param message the message to insert into the queue
|
|
||||||
*
|
*
|
||||||
* @return {@code true} if the destination device is "present" (i.e. has an active event listener) or {@code false}
|
* @return a map of device IDs to a device's presence state (i.e. if the device has an active event listener)
|
||||||
* otherwise
|
|
||||||
*
|
*
|
||||||
* @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager
|
* @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager
|
||||||
*/
|
*/
|
||||||
public boolean insert(final UUID destinationUuid, final byte destinationDeviceId, final Envelope message) {
|
public Map<Byte, Boolean> insert(final UUID accountIdentifier, final Map<Byte, Envelope> messagesByDeviceId) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
return insertAsync(accountIdentifier, messagesByDeviceId).join();
|
||||||
|
}
|
||||||
|
|
||||||
final boolean destinationPresent = messagesCache.insert(messageGuid, destinationUuid, destinationDeviceId, message);
|
private CompletableFuture<Map<Byte, Boolean>> insertAsync(final UUID accountIdentifier, final Map<Byte, Envelope> messagesByDeviceId) {
|
||||||
|
final Map<Byte, Boolean> devicePresenceById = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
if (message.hasSourceServiceId() && !destinationUuid.toString().equals(message.getSourceServiceId())) {
|
return CompletableFuture.allOf(messagesByDeviceId.entrySet().stream()
|
||||||
reportMessageManager.store(message.getSourceServiceId(), messageGuid);
|
.map(deviceIdAndMessage -> {
|
||||||
}
|
final byte deviceId = deviceIdAndMessage.getKey();
|
||||||
|
final Envelope message = deviceIdAndMessage.getValue();
|
||||||
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
|
|
||||||
return destinationPresent;
|
return messagesCache.insert(messageGuid, accountIdentifier, deviceId, message)
|
||||||
|
.thenAccept(present -> {
|
||||||
|
if (message.hasSourceServiceId() && !accountIdentifier.toString()
|
||||||
|
.equals(message.getSourceServiceId())) {
|
||||||
|
// Note that this is an asynchronous, best-effort, fire-and-forget operation
|
||||||
|
reportMessageManager.store(message.getSourceServiceId(), messageGuid);
|
||||||
|
}
|
||||||
|
|
||||||
|
devicePresenceById.put(deviceId, present);
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.toArray(CompletableFuture[]::new))
|
||||||
|
.thenApply(ignored -> devicePresenceById);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inserts messages into the message queues for devices associated with the identified accounts.
|
||||||
|
*
|
||||||
|
* @param multiRecipientMessage the multi-recipient message to insert into destination queues
|
||||||
|
* @param resolvedRecipients a map of multi-recipient message {@code Recipient} entities to their corresponding
|
||||||
|
* Signal accounts; messages will not be delivered to unresolved recipients
|
||||||
|
* @param clientTimestamp the timestamp for the message as reported by the sending party
|
||||||
|
* @param isStory {@code true} if the given message is a story or {@code false} otherwise
|
||||||
|
* @param isEphemeral {@code true} if the given message should only be delivered to devices with active
|
||||||
|
* connections to a Signal server or {@code false} otherwise
|
||||||
|
* @param isUrgent {@code true} if the given message is urgent or {@code false} otherwise
|
||||||
|
*
|
||||||
|
* @return a map of accounts to maps of device IDs to a device's presence state (i.e. if the device has an active
|
||||||
|
* event listener)
|
||||||
|
*
|
||||||
|
* @see org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager
|
||||||
|
*/
|
||||||
|
public CompletableFuture<Map<Account, Map<Byte, Boolean>>> insertMultiRecipientMessage(
|
||||||
|
final SealedSenderMultiRecipientMessage multiRecipientMessage,
|
||||||
|
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients,
|
||||||
|
final long clientTimestamp,
|
||||||
|
final boolean isStory,
|
||||||
|
final boolean isEphemeral,
|
||||||
|
final boolean isUrgent) {
|
||||||
|
|
||||||
|
final long serverTimestamp = clock.millis();
|
||||||
|
|
||||||
|
return insertSharedMultiRecipientMessagePayload(multiRecipientMessage)
|
||||||
|
.thenCompose(sharedMrmKey -> {
|
||||||
|
final Envelope prototypeMessage = Envelope.newBuilder()
|
||||||
|
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
|
||||||
|
.setClientTimestamp(clientTimestamp == 0 ? serverTimestamp : clientTimestamp)
|
||||||
|
.setServerTimestamp(serverTimestamp)
|
||||||
|
.setStory(isStory)
|
||||||
|
.setEphemeral(isEphemeral)
|
||||||
|
.setUrgent(isUrgent)
|
||||||
|
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
final Map<Account, Map<Byte, Boolean>> clientPresenceByAccountAndDevice = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
return CompletableFuture.allOf(multiRecipientMessage.getRecipients().entrySet().stream()
|
||||||
|
.filter(serviceIdAndRecipient -> resolvedRecipients.containsKey(serviceIdAndRecipient.getValue()))
|
||||||
|
.map(serviceIdAndRecipient -> {
|
||||||
|
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey());
|
||||||
|
final SealedSenderMultiRecipientMessage.Recipient recipient = serviceIdAndRecipient.getValue();
|
||||||
|
final byte[] devices = recipient.getDevices();
|
||||||
|
|
||||||
|
return insertAsync(resolvedRecipients.get(recipient).getIdentifier(IdentityType.ACI),
|
||||||
|
IntStream.range(0, devices.length).mapToObj(i -> devices[i])
|
||||||
|
.collect(Collectors.toMap(deviceId -> deviceId, deviceId -> prototypeMessage.toBuilder()
|
||||||
|
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
|
||||||
|
.build())))
|
||||||
|
.thenAccept(clientPresenceByDeviceId ->
|
||||||
|
clientPresenceByAccountAndDevice.put(resolvedRecipients.get(recipient),
|
||||||
|
clientPresenceByDeviceId));
|
||||||
|
})
|
||||||
|
.toArray(CompletableFuture[]::new))
|
||||||
|
.thenApply(ignored -> clientPresenceByAccountAndDevice);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public CompletableFuture<Boolean> mayHavePersistedMessages(final UUID destinationUuid, final Device destinationDevice) {
|
public CompletableFuture<Boolean> mayHavePersistedMessages(final UUID destinationUuid, final Device destinationDevice) {
|
||||||
|
@ -217,7 +303,7 @@ public class MessagesManager {
|
||||||
* @return a key where the shared data is stored
|
* @return a key where the shared data is stored
|
||||||
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
|
* @see MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript
|
||||||
*/
|
*/
|
||||||
public byte[] insertSharedMultiRecipientMessagePayload(
|
private CompletableFuture<byte[]> insertSharedMultiRecipientMessagePayload(
|
||||||
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
|
final SealedSenderMultiRecipientMessage sealedSenderMultiRecipientMessage) {
|
||||||
return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage);
|
return messagesCache.insertSharedMultiRecipientMessagePayload(sealedSenderMultiRecipientMessage);
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ package org.whispersystems.textsecuregcm.storage;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.Timer;
|
import io.micrometer.core.instrument.Timer;
|
||||||
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
import org.whispersystems.textsecuregcm.util.AttributeValues;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
|
||||||
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
|
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
|
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
|
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
|
||||||
|
@ -11,6 +13,7 @@ import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||||
|
|
||||||
|
@ -20,6 +23,7 @@ public class ReportMessageDynamoDb {
|
||||||
static final String ATTR_TTL = "E";
|
static final String ATTR_TTL = "E";
|
||||||
|
|
||||||
private final DynamoDbClient db;
|
private final DynamoDbClient db;
|
||||||
|
private final DynamoDbAsyncClient dynamoDbAsyncClient;
|
||||||
private final String tableName;
|
private final String tableName;
|
||||||
private final Duration ttl;
|
private final Duration ttl;
|
||||||
|
|
||||||
|
@ -30,20 +34,26 @@ public class ReportMessageDynamoDb {
|
||||||
.distributionStatisticExpiry(Duration.ofDays(1))
|
.distributionStatisticExpiry(Duration.ofDays(1))
|
||||||
.register(Metrics.globalRegistry);
|
.register(Metrics.globalRegistry);
|
||||||
|
|
||||||
public ReportMessageDynamoDb(final DynamoDbClient dynamoDB, final String tableName, final Duration ttl) {
|
public ReportMessageDynamoDb(final DynamoDbClient dynamoDB,
|
||||||
|
final DynamoDbAsyncClient dynamoDbAsyncClient,
|
||||||
|
final String tableName,
|
||||||
|
final Duration ttl) {
|
||||||
|
|
||||||
this.db = dynamoDB;
|
this.db = dynamoDB;
|
||||||
|
this.dynamoDbAsyncClient = dynamoDbAsyncClient;
|
||||||
this.tableName = tableName;
|
this.tableName = tableName;
|
||||||
this.ttl = ttl;
|
this.ttl = ttl;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void store(byte[] hash) {
|
public CompletableFuture<Void> store(byte[] hash) {
|
||||||
db.putItem(PutItemRequest.builder()
|
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
|
||||||
.tableName(tableName)
|
.tableName(tableName)
|
||||||
.item(Map.of(
|
.item(Map.of(
|
||||||
KEY_HASH, AttributeValues.fromByteArray(hash),
|
KEY_HASH, AttributeValues.fromByteArray(hash),
|
||||||
ATTR_TTL, AttributeValues.fromLong(Instant.now().plus(ttl).getEpochSecond())
|
ATTR_TTL, AttributeValues.fromLong(Instant.now().plus(ttl).getEpochSecond())
|
||||||
))
|
))
|
||||||
.build());
|
.build())
|
||||||
|
.thenRun(Util.NOOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean remove(byte[] hash) {
|
public boolean remove(byte[] hash) {
|
||||||
|
|
|
@ -54,11 +54,8 @@ public class ReportMessageManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void store(String sourceAci, UUID messageGuid) {
|
public void store(String sourceAci, UUID messageGuid) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Objects.requireNonNull(sourceAci);
|
reportMessageDynamoDb.store(hash(messageGuid, Objects.requireNonNull(sourceAci)));
|
||||||
|
|
||||||
reportMessageDynamoDb.store(hash(messageGuid, sourceAci));
|
|
||||||
} catch (final Exception e) {
|
} catch (final Exception e) {
|
||||||
logger.warn("Failed to store hash", e);
|
logger.warn("Failed to store hash", e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,13 +22,15 @@ public class DestinationDeviceValidator {
|
||||||
/**
|
/**
|
||||||
* @see #validateRegistrationIds(Account, Stream, boolean)
|
* @see #validateRegistrationIds(Account, Stream, boolean)
|
||||||
*/
|
*/
|
||||||
public static <T> void validateRegistrationIds(final Account account, final Collection<T> messages,
|
public static <T> void validateRegistrationIds(final Account account,
|
||||||
Function<T, Byte> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity)
|
final Collection<T> messages,
|
||||||
throws StaleDevicesException {
|
Function<T, Byte> getDeviceId,
|
||||||
|
Function<T, Integer> getRegistrationId,
|
||||||
|
boolean usePhoneNumberIdentity) throws StaleDevicesException {
|
||||||
|
|
||||||
validateRegistrationIds(account,
|
validateRegistrationIds(account,
|
||||||
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
|
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
|
||||||
usePhoneNumberIdentity);
|
usePhoneNumberIdentity);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -217,13 +217,13 @@ record CommandDependencies(
|
||||||
MessagesCache messagesCache = new MessagesCache(messagesCluster,
|
MessagesCache messagesCache = new MessagesCache(messagesCluster,
|
||||||
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
|
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
|
||||||
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
|
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
|
||||||
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient,
|
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
|
||||||
configuration.getDynamoDbTables().getReportMessage().getTableName(),
|
configuration.getDynamoDbTables().getReportMessage().getTableName(),
|
||||||
configuration.getReportMessageConfiguration().getReportTtl());
|
configuration.getReportMessageConfiguration().getReportTtl());
|
||||||
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
|
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
|
||||||
configuration.getReportMessageConfiguration().getCounterTtl());
|
configuration.getReportMessageConfiguration().getCounterTtl());
|
||||||
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
|
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
|
||||||
reportMessageManager, messageDeletionExecutor);
|
reportMessageManager, messageDeletionExecutor, Clock.systemUTC());
|
||||||
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
|
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
|
||||||
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
|
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
|
||||||
ClientPublicKeysManager clientPublicKeysManager =
|
ClientPublicKeysManager clientPublicKeysManager =
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -10,23 +10,25 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||||
import static org.mockito.ArgumentMatchers.anyByte;
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.protobuf.ByteString;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import org.apache.commons.lang3.RandomStringUtils;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||||
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
|
@ -49,17 +51,21 @@ class MessageSenderTest {
|
||||||
|
|
||||||
@CartesianTest
|
@CartesianTest
|
||||||
void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||||
@CartesianTest.Values(booleans = {true, false}) final boolean onlineMessage,
|
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||||
|
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||||
|
|
||||||
final boolean expectPushNotificationAttempt = !clientPresent && !onlineMessage;
|
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||||
|
|
||||||
final UUID accountIdentifier = UUID.randomUUID();
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
final byte deviceId = Device.PRIMARY_ID;
|
final byte deviceId = Device.PRIMARY_ID;
|
||||||
|
|
||||||
final Account account = mock(Account.class);
|
final Account account = mock(Account.class);
|
||||||
final Device device = mock(Device.class);
|
final Device device = mock(Device.class);
|
||||||
final MessageProtos.Envelope message = generateRandomMessage();
|
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder()
|
||||||
|
.setEphemeral(ephemeral)
|
||||||
|
.setUrgent(urgent)
|
||||||
|
.build();
|
||||||
|
|
||||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||||
|
@ -72,18 +78,61 @@ class MessageSenderTest {
|
||||||
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
|
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
|
||||||
}
|
}
|
||||||
|
|
||||||
when(messagesManager.insert(eq(accountIdentifier), eq(deviceId), any())).thenReturn(clientPresent);
|
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
|
||||||
|
|
||||||
assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, onlineMessage));
|
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
|
||||||
|
|
||||||
final MessageProtos.Envelope expectedMessage = onlineMessage
|
final MessageProtos.Envelope expectedMessage = ephemeral
|
||||||
? message.toBuilder().setEphemeral(true).build()
|
? message.toBuilder().setEphemeral(true).build()
|
||||||
: message.toBuilder().build();
|
: message.toBuilder().build();
|
||||||
|
|
||||||
verify(messagesManager).insert(accountIdentifier, deviceId, expectedMessage);
|
verify(messagesManager).insert(accountIdentifier, Map.of(deviceId, expectedMessage));
|
||||||
|
|
||||||
if (expectPushNotificationAttempt) {
|
if (expectPushNotificationAttempt) {
|
||||||
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, expectedMessage.getUrgent());
|
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
|
||||||
|
} else {
|
||||||
|
verifyNoInteractions(pushNotificationManager);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@CartesianTest
|
||||||
|
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||||
|
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||||
|
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||||
|
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||||
|
|
||||||
|
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||||
|
|
||||||
|
final UUID accountIdentifier = UUID.randomUUID();
|
||||||
|
final byte deviceId = Device.PRIMARY_ID;
|
||||||
|
|
||||||
|
final Account account = mock(Account.class);
|
||||||
|
final Device device = mock(Device.class);
|
||||||
|
|
||||||
|
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||||
|
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||||
|
when(device.getId()).thenReturn(deviceId);
|
||||||
|
|
||||||
|
if (hasPushToken) {
|
||||||
|
when(device.getApnId()).thenReturn("apns-token");
|
||||||
|
} else {
|
||||||
|
doThrow(NotPushRegisteredException.class)
|
||||||
|
.when(pushNotificationManager).sendNewMessageNotification(any(), anyByte(), anyBoolean());
|
||||||
|
}
|
||||||
|
|
||||||
|
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
|
||||||
|
Collections.emptyMap(),
|
||||||
|
System.currentTimeMillis(),
|
||||||
|
false,
|
||||||
|
ephemeral,
|
||||||
|
urgent)
|
||||||
|
.join());
|
||||||
|
|
||||||
|
if (expectPushNotificationAttempt) {
|
||||||
|
verify(pushNotificationManager).sendNewMessageNotification(account, deviceId, urgent);
|
||||||
} else {
|
} else {
|
||||||
verifyNoInteractions(pushNotificationManager);
|
verifyNoInteractions(pushNotificationManager);
|
||||||
}
|
}
|
||||||
|
@ -123,14 +172,4 @@ class MessageSenderTest {
|
||||||
|
|
||||||
return arguments;
|
return arguments;
|
||||||
}
|
}
|
||||||
|
|
||||||
private MessageProtos.Envelope generateRandomMessage() {
|
|
||||||
return MessageProtos.Envelope.newBuilder()
|
|
||||||
.setClientTimestamp(System.currentTimeMillis())
|
|
||||||
.setServerTimestamp(System.currentTimeMillis())
|
|
||||||
.setContent(ByteString.copyFromUtf8(RandomStringUtils.secure().nextAlphanumeric(256)))
|
|
||||||
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
|
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -104,7 +105,7 @@ public class ChangeNumberManagerTest {
|
||||||
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
|
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
|
||||||
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
|
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
|
||||||
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
|
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
|
||||||
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
|
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -118,7 +119,7 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
|
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
|
||||||
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
|
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
|
||||||
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
|
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -155,10 +156,15 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
|
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
|
||||||
|
|
||||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
ArgumentCaptor.forClass(Map.class);
|
||||||
|
|
||||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(1, envelopeCaptor.getValue().size());
|
||||||
|
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||||
|
|
||||||
|
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||||
|
|
||||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||||
|
@ -203,10 +209,15 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||||
|
|
||||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
ArgumentCaptor.forClass(Map.class);
|
||||||
|
|
||||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(1, envelopeCaptor.getValue().size());
|
||||||
|
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||||
|
|
||||||
|
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||||
|
|
||||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||||
|
@ -249,10 +260,15 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||||
|
|
||||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
ArgumentCaptor.forClass(Map.class);
|
||||||
|
|
||||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(1, envelopeCaptor.getValue().size());
|
||||||
|
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||||
|
|
||||||
|
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||||
|
|
||||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||||
|
@ -291,10 +307,15 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
|
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
|
||||||
|
|
||||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
ArgumentCaptor.forClass(Map.class);
|
||||||
|
|
||||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(1, envelopeCaptor.getValue().size());
|
||||||
|
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||||
|
|
||||||
|
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||||
|
|
||||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||||
|
@ -335,10 +356,15 @@ public class ChangeNumberManagerTest {
|
||||||
|
|
||||||
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
|
||||||
|
|
||||||
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
|
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||||
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
|
ArgumentCaptor.forClass(Map.class);
|
||||||
|
|
||||||
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
|
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(1, envelopeCaptor.getValue().size());
|
||||||
|
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||||
|
|
||||||
|
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
|
||||||
|
|
||||||
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
|
||||||
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
|
||||||
|
|
|
@ -84,7 +84,7 @@ class MessagePersisterIntegrationTest {
|
||||||
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
|
messageDeliveryScheduler, messageDeletionExecutorService, Clock.systemUTC());
|
||||||
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
|
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
|
||||||
messageDeletionExecutorService);
|
messageDeletionExecutorService, Clock.systemUTC());
|
||||||
|
|
||||||
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor();
|
||||||
asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor();
|
asyncOperationQueueingExecutor = Executors.newSingleThreadExecutor();
|
||||||
|
@ -143,7 +143,7 @@ class MessagePersisterIntegrationTest {
|
||||||
|
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message);
|
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message).join();
|
||||||
expectedMessages.add(message);
|
expectedMessages.add(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -358,7 +358,7 @@ class MessagePersisterTest {
|
||||||
.setServerGuid(messageGuid.toString())
|
.setServerGuid(messageGuid.toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
|
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope).join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class MessagesCacheGetItemsScriptTest {
|
||||||
.setServerGuid(serverGuid)
|
.setServerGuid(serverGuid)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||||
|
|
||||||
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
|
final MessagesCacheGetItemsScript getItemsScript = new MessagesCacheGetItemsScript(
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
|
@ -41,7 +41,7 @@ class MessagesCacheInsertScriptTest {
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
.setServerGuid(UUID.randomUUID().toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||||
|
|
||||||
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
|
assertEquals(List.of(envelope1), getStoredMessages(destinationUuid, deviceId));
|
||||||
|
|
||||||
|
@ -50,11 +50,11 @@ class MessagesCacheInsertScriptTest {
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
.setServerGuid(UUID.randomUUID().toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope2);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope2);
|
||||||
|
|
||||||
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
|
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId));
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||||
|
|
||||||
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
|
assertEquals(List.of(envelope1, envelope2), getStoredMessages(destinationUuid, deviceId),
|
||||||
"Messages with same GUID should be deduplicated");
|
"Messages with same GUID should be deduplicated");
|
||||||
|
@ -89,10 +89,10 @@ class MessagesCacheInsertScriptTest {
|
||||||
final MessagesCacheInsertScript insertScript =
|
final MessagesCacheInsertScript insertScript =
|
||||||
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
||||||
assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
assertFalse(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||||
.setServerTimestamp(Instant.now().getEpochSecond())
|
.setServerTimestamp(Instant.now().getEpochSecond())
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
.setServerGuid(UUID.randomUUID().toString())
|
||||||
.build()));
|
.build()).join());
|
||||||
|
|
||||||
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
|
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
|
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
|
||||||
|
@ -100,9 +100,9 @@ class MessagesCacheInsertScriptTest {
|
||||||
pubSubClusterConnection.usePubSubConnection(connection ->
|
pubSubClusterConnection.usePubSubConnection(connection ->
|
||||||
connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId)));
|
connection.sync().ssubscribe(WebSocketConnectionEventManager.getClientEventChannel(destinationUuid, deviceId)));
|
||||||
|
|
||||||
assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
assertTrue(insertScript.executeAsync(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
|
||||||
.setServerTimestamp(Instant.now().getEpochSecond())
|
.setServerTimestamp(Instant.now().getEpochSecond())
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
.setServerGuid(UUID.randomUUID().toString())
|
||||||
.build()));
|
.build()).join());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,9 @@
|
||||||
package org.whispersystems.textsecuregcm.storage;
|
package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
import io.lettuce.core.RedisCommandExecutionException;
|
import io.lettuce.core.RedisCommandExecutionException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -14,8 +16,10 @@ import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletionException;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
import io.lettuce.core.RedisException;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
@ -39,8 +43,8 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
||||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||||
insertMrmScript.execute(sharedMrmKey,
|
insertMrmScript.executeAsync(sharedMrmKey,
|
||||||
MessagesCacheTest.generateRandomMrmMessage(destinations));
|
MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
|
||||||
|
|
||||||
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
|
final int totalDevices = destinations.values().stream().mapToInt(List::size).sum();
|
||||||
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
final long hashFieldCount = REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
||||||
|
@ -82,15 +86,17 @@ class MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScriptTest {
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
||||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||||
insertMrmScript.execute(sharedMrmKey,
|
insertMrmScript.executeAsync(sharedMrmKey,
|
||||||
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID));
|
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()), Device.PRIMARY_ID)).join();
|
||||||
|
|
||||||
final RedisCommandExecutionException e = assertThrows(RedisCommandExecutionException.class,
|
final CompletionException completionException = assertThrows(CompletionException.class,
|
||||||
() -> insertMrmScript.execute(sharedMrmKey,
|
() -> insertMrmScript.executeAsync(sharedMrmKey,
|
||||||
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
|
MessagesCacheTest.generateRandomMrmMessage(new AciServiceIdentifier(UUID.randomUUID()),
|
||||||
Device.PRIMARY_ID)));
|
Device.PRIMARY_ID)).join());
|
||||||
|
|
||||||
assertEquals(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS, e.getMessage());
|
assertInstanceOf(RedisException.class, completionException.getCause());
|
||||||
|
assertTrue(completionException.getCause().getMessage()
|
||||||
|
.contains(MessagesCacheInsertSharedMultiRecipientPayloadAndViewsScript.ERROR_KEY_EXISTS));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ class MessagesCacheRemoveByGuidScriptTest {
|
||||||
.setServerGuid(serverGuid.toString())
|
.setServerGuid(serverGuid.toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||||
|
|
||||||
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
|
final MessagesCacheRemoveByGuidScript removeByGuidScript = new MessagesCacheRemoveByGuidScript(
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
|
@ -35,7 +35,7 @@ class MessagesCacheRemoveQueueScriptTest {
|
||||||
.setServerGuid(UUID.randomUUID().toString())
|
.setServerGuid(UUID.randomUUID().toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
insertScript.execute(destinationUuid, deviceId, envelope1);
|
insertScript.executeAsync(destinationUuid, deviceId, envelope1);
|
||||||
|
|
||||||
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
|
final MessagesCacheRemoveQueueScript removeScript = new MessagesCacheRemoveQueueScript(
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
|
@ -41,8 +41,7 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
||||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||||
insertMrmScript.execute(sharedMrmKey,
|
insertMrmScript.executeAsync(sharedMrmKey, MessagesCacheTest.generateRandomMrmMessage(destinations)).join();
|
||||||
MessagesCacheTest.generateRandomMrmMessage(destinations));
|
|
||||||
|
|
||||||
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
|
final MessagesCacheRemoveRecipientViewFromMrmDataScript removeRecipientViewFromMrmDataScript = new MessagesCacheRemoveRecipientViewFromMrmDataScript(
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
@ -103,8 +102,8 @@ class MessagesCacheRemoveRecipientViewFromMrmDataScriptTest {
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
REDIS_CLUSTER_EXTENSION.getRedisCluster());
|
||||||
|
|
||||||
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
final byte[] sharedMrmKey = MessagesCache.getSharedMrmKey(UUID.randomUUID());
|
||||||
insertMrmScript.execute(sharedMrmKey,
|
insertMrmScript.executeAsync(sharedMrmKey,
|
||||||
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId));
|
MessagesCacheTest.generateRandomMrmMessage(serviceIdentifier, deviceId)).join();
|
||||||
|
|
||||||
sharedMrmKeys.add(sharedMrmKey);
|
sharedMrmKeys.add(sharedMrmKey);
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,7 +122,7 @@ class MessagesCacheTest {
|
||||||
void testInsert(final boolean sealedSender) {
|
void testInsert(final boolean sealedSender) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
assertDoesNotThrow(() -> messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
generateRandomMessage(messageGuid, sealedSender)));
|
generateRandomMessage(messageGuid, sealedSender))).join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -130,8 +130,8 @@ class MessagesCacheTest {
|
||||||
final UUID duplicateGuid = UUID.randomUUID();
|
final UUID duplicateGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
|
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
|
||||||
|
|
||||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
|
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
|
||||||
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
|
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage).join();
|
||||||
|
|
||||||
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
|
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
|
||||||
.count()
|
.count()
|
||||||
|
@ -149,7 +149,7 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
|
final Optional<RemovedMessage> maybeRemovedMessage = messagesCache.remove(DESTINATION_UUID,
|
||||||
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
|
DESTINATION_DEVICE_ID, messageGuid).get(5, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
@ -175,12 +175,12 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
for (final MessageProtos.Envelope message : messagesToRemove) {
|
for (final MessageProtos.Envelope message : messagesToRemove) {
|
||||||
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
message);
|
message).join();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (final MessageProtos.Envelope message : messagesToPreserve) {
|
for (final MessageProtos.Envelope message : messagesToPreserve) {
|
||||||
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
messagesCache.insert(UUID.fromString(message.getServerGuid()), DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
message);
|
message).join();
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
final List<RemovedMessage> removedMessages = messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
|
@ -197,7 +197,7 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
|
|
||||||
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
|
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
|
|
||||||
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
|
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
|
||||||
}
|
}
|
||||||
|
@ -223,7 +223,7 @@ class MessagesCacheTest {
|
||||||
for (int i = 0; i < messageCount; i++) {
|
for (int i = 0; i < messageCount; i++) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
assertEquals(expectedOldestTimestamp,
|
assertEquals(expectedOldestTimestamp,
|
||||||
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
|
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
|
||||||
expectedMessages.add(message);
|
expectedMessages.add(message);
|
||||||
|
@ -248,7 +248,7 @@ class MessagesCacheTest {
|
||||||
for (int i = 0; i < messageCount; i++) {
|
for (int i = 0; i < messageCount; i++) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
expectedMessages.add(message);
|
expectedMessages.add(message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,7 +262,7 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
final UUID message1Guid = UUID.randomUUID();
|
final UUID message1Guid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender);
|
final MessageProtos.Envelope message1 = generateRandomMessage(message1Guid, sealedSender);
|
||||||
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1);
|
messagesCache.insert(message1Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message1).join();
|
||||||
final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
final List<MessageProtos.Envelope> get1 = get(DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
1);
|
1);
|
||||||
assertEquals(List.of(message1), get1);
|
assertEquals(List.of(message1), get1);
|
||||||
|
@ -272,7 +272,7 @@ class MessagesCacheTest {
|
||||||
final UUID message2Guid = UUID.randomUUID();
|
final UUID message2Guid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender);
|
final MessageProtos.Envelope message2 = generateRandomMessage(message2Guid, sealedSender);
|
||||||
|
|
||||||
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2);
|
messagesCache.insert(message2Guid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message2).join();
|
||||||
|
|
||||||
assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1));
|
assertEquals(List.of(message2), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, 1));
|
||||||
}
|
}
|
||||||
|
@ -287,7 +287,7 @@ class MessagesCacheTest {
|
||||||
for (int i = 0; i < messageCount; i++) {
|
for (int i = 0; i < messageCount; i++) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true);
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message).join();
|
||||||
|
|
||||||
expectedMessages.add(message);
|
expectedMessages.add(message);
|
||||||
}
|
}
|
||||||
|
@ -295,7 +295,7 @@ class MessagesCacheTest {
|
||||||
final UUID ephemeralMessageGuid = UUID.randomUUID();
|
final UUID ephemeralMessageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true)
|
final MessageProtos.Envelope ephemeralMessage = generateRandomMessage(ephemeralMessageGuid, true)
|
||||||
.toBuilder().setEphemeral(true).build();
|
.toBuilder().setEphemeral(true).build();
|
||||||
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage);
|
messagesCache.insert(ephemeralMessageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, ephemeralMessage).join();
|
||||||
|
|
||||||
final Clock cacheClock;
|
final Clock cacheClock;
|
||||||
if (expectStale) {
|
if (expectStale) {
|
||||||
|
@ -352,7 +352,7 @@ class MessagesCacheTest {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -372,7 +372,7 @@ class MessagesCacheTest {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message);
|
messagesCache.insert(messageGuid, DESTINATION_UUID, deviceId, message).join();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,7 +404,7 @@ class MessagesCacheTest {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||||
generateRandomMessage(messageGuid, sealedSender));
|
generateRandomMessage(messageGuid, sealedSender)).join();
|
||||||
final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
|
final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
|
||||||
|
|
||||||
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
|
assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
|
||||||
|
@ -427,7 +427,7 @@ class MessagesCacheTest {
|
||||||
|
|
||||||
final byte[] sharedMrmDataKey;
|
final byte[] sharedMrmDataKey;
|
||||||
if (sharedMrmKeyPresent) {
|
if (sharedMrmKeyPresent) {
|
||||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
|
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
|
||||||
} else {
|
} else {
|
||||||
sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8);
|
sharedMrmDataKey = "{1}".getBytes(StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
|
@ -440,7 +440,7 @@ class MessagesCacheTest {
|
||||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
||||||
.clearContent()
|
.clearContent()
|
||||||
.build();
|
.build();
|
||||||
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message);
|
messagesCache.insert(guid, destinationServiceId.uuid(), deviceId, message).join();
|
||||||
|
|
||||||
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
assertEquals(sharedMrmKeyPresent ? 1 : 0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster()
|
||||||
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
|
.withBinaryCluster(conn -> conn.sync().exists(sharedMrmDataKey)));
|
||||||
|
@ -487,13 +487,13 @@ class MessagesCacheTest {
|
||||||
final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
|
final MessageProtos.Envelope message = generateRandomMessage(messageGuid,
|
||||||
new AciServiceIdentifier(destinationUuid), true);
|
new AciServiceIdentifier(destinationUuid), true);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, destinationUuid, deviceId, message);
|
messagesCache.insert(messageGuid, destinationUuid, deviceId, message).join();
|
||||||
|
|
||||||
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
|
final SealedSenderMultiRecipientMessage mrm = generateRandomMrmMessage(destinationServiceId, deviceId);
|
||||||
|
|
||||||
final byte[] sharedMrmDataKey;
|
final byte[] sharedMrmDataKey;
|
||||||
if (sharedMrmKeyPresent) {
|
if (sharedMrmKeyPresent) {
|
||||||
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm);
|
sharedMrmDataKey = messagesCache.insertSharedMultiRecipientMessagePayload(mrm).join();
|
||||||
} else {
|
} else {
|
||||||
sharedMrmDataKey = new byte[]{1};
|
sharedMrmDataKey = new byte[]{1};
|
||||||
}
|
}
|
||||||
|
@ -505,7 +505,7 @@ class MessagesCacheTest {
|
||||||
.clearContent()
|
.clearContent()
|
||||||
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
.setSharedMrmKey(ByteString.copyFrom(sharedMrmDataKey))
|
||||||
.build();
|
.build();
|
||||||
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage);
|
messagesCache.insert(mrmMessageGuid, destinationUuid, deviceId, mrmMessage).join();
|
||||||
|
|
||||||
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);
|
final List<MessageProtos.Envelope> messages = messagesCache.getMessagesToPersist(destinationUuid, deviceId, 100);
|
||||||
|
|
||||||
|
|
|
@ -7,22 +7,42 @@ package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.verifyNoInteractions;
|
import static org.mockito.Mockito.verifyNoInteractions;
|
||||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.protobuf.ByteString;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.CsvSource;
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
|
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||||
|
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||||
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||||
|
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
class MessagesManagerTest {
|
class MessagesManagerTest {
|
||||||
|
@ -31,8 +51,15 @@ class MessagesManagerTest {
|
||||||
private final MessagesCache messagesCache = mock(MessagesCache.class);
|
private final MessagesCache messagesCache = mock(MessagesCache.class);
|
||||||
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
|
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
|
||||||
|
|
||||||
|
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
|
||||||
|
|
||||||
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
|
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
|
||||||
reportMessageManager, Executors.newSingleThreadExecutor());
|
reportMessageManager, Executors.newSingleThreadExecutor(), CLOCK);
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
when(messagesCache.insert(any(), any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(true));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void insert() {
|
void insert() {
|
||||||
|
@ -43,7 +70,7 @@ class MessagesManagerTest {
|
||||||
|
|
||||||
final UUID destinationUuid = UUID.randomUUID();
|
final UUID destinationUuid = UUID.randomUUID();
|
||||||
|
|
||||||
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message);
|
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, message));
|
||||||
|
|
||||||
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
|
verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class));
|
||||||
|
|
||||||
|
@ -51,11 +78,113 @@ class MessagesManagerTest {
|
||||||
.setSourceServiceId(destinationUuid.toString())
|
.setSourceServiceId(destinationUuid.toString())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage);
|
messagesManager.insert(destinationUuid, Map.of(Device.PRIMARY_ID, syncMessage));
|
||||||
|
|
||||||
verifyNoMoreInteractions(reportMessageManager);
|
verifyNoMoreInteractions(reportMessageManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void insertMultiRecipientMessage() throws InvalidMessageException, InvalidVersionException {
|
||||||
|
final ServiceIdentifier singleDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||||
|
final ServiceIdentifier singleDeviceAccountPniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
|
||||||
|
final ServiceIdentifier multiDeviceAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||||
|
final ServiceIdentifier unresolvedAccountAciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||||
|
|
||||||
|
final Account singleDeviceAccount = mock(Account.class);
|
||||||
|
final Account multiDeviceAccount = mock(Account.class);
|
||||||
|
|
||||||
|
when(singleDeviceAccount.getIdentifier(IdentityType.ACI))
|
||||||
|
.thenReturn(singleDeviceAccountAciServiceIdentifier.uuid());
|
||||||
|
|
||||||
|
when(multiDeviceAccount.getIdentifier(IdentityType.ACI))
|
||||||
|
.thenReturn(multiDeviceAccountAciServiceIdentifier.uuid());
|
||||||
|
|
||||||
|
final byte[] multiRecipientMessageBytes = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(
|
||||||
|
new TestRecipient(singleDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]),
|
||||||
|
new TestRecipient(multiDeviceAccountAciServiceIdentifier, Device.PRIMARY_ID, 2, new byte[48]),
|
||||||
|
new TestRecipient(multiDeviceAccountAciServiceIdentifier, (byte) (Device.PRIMARY_ID + 1), 3, new byte[48]),
|
||||||
|
new TestRecipient(unresolvedAccountAciServiceIdentifier, Device.PRIMARY_ID, 4, new byte[48]),
|
||||||
|
new TestRecipient(singleDeviceAccountPniServiceIdentifier, Device.PRIMARY_ID, 5, new byte[48])
|
||||||
|
));
|
||||||
|
|
||||||
|
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||||
|
SealedSenderMultiRecipientMessage.parse(multiRecipientMessageBytes);
|
||||||
|
|
||||||
|
final Map<SealedSenderMultiRecipientMessage.Recipient, Account> resolvedRecipients = new HashMap<>();
|
||||||
|
|
||||||
|
multiRecipientMessage.getRecipients().forEach(((serviceId, recipient) -> {
|
||||||
|
if (serviceId.getRawUUID().equals(singleDeviceAccountAciServiceIdentifier.uuid()) ||
|
||||||
|
serviceId.getRawUUID().equals(singleDeviceAccountPniServiceIdentifier.uuid())) {
|
||||||
|
resolvedRecipients.put(recipient, singleDeviceAccount);
|
||||||
|
} else if (serviceId.getRawUUID().equals(multiDeviceAccountAciServiceIdentifier.uuid())) {
|
||||||
|
resolvedRecipients.put(recipient, multiDeviceAccount);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
|
final Map<Account, Map<Byte, Boolean>> expectedPresenceByAccountAndDeviceId = Map.of(
|
||||||
|
singleDeviceAccount, Map.of(Device.PRIMARY_ID, true),
|
||||||
|
multiDeviceAccount, Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
|
||||||
|
);
|
||||||
|
|
||||||
|
final Map<UUID, Map<Byte, Boolean>> presenceByAccountIdentifierAndDeviceId = Map.of(
|
||||||
|
singleDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, true),
|
||||||
|
multiDeviceAccountAciServiceIdentifier.uuid(), Map.of(Device.PRIMARY_ID, false, (byte) (Device.PRIMARY_ID + 1), true)
|
||||||
|
);
|
||||||
|
|
||||||
|
final byte[] sharedMrmKey = "shared-mrm-key".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
when(messagesCache.insertSharedMultiRecipientMessagePayload(multiRecipientMessage))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(sharedMrmKey));
|
||||||
|
|
||||||
|
when(messagesCache.insert(any(), any(), anyByte(), any()))
|
||||||
|
.thenAnswer(invocation -> {
|
||||||
|
final UUID accountIdentifier = invocation.getArgument(1);
|
||||||
|
final byte deviceId = invocation.getArgument(2);
|
||||||
|
|
||||||
|
return CompletableFuture.completedFuture(
|
||||||
|
presenceByAccountIdentifierAndDeviceId.getOrDefault(accountIdentifier, Collections.emptyMap())
|
||||||
|
.getOrDefault(deviceId, false));
|
||||||
|
});
|
||||||
|
|
||||||
|
final long clientTimestamp = System.currentTimeMillis();
|
||||||
|
final boolean isStory = ThreadLocalRandom.current().nextBoolean();
|
||||||
|
final boolean isEphemeral = ThreadLocalRandom.current().nextBoolean();
|
||||||
|
final boolean isUrgent = ThreadLocalRandom.current().nextBoolean();
|
||||||
|
|
||||||
|
final Envelope prototypeExpectedMessage = Envelope.newBuilder()
|
||||||
|
.setType(Envelope.Type.UNIDENTIFIED_SENDER)
|
||||||
|
.setClientTimestamp(clientTimestamp)
|
||||||
|
.setServerTimestamp(CLOCK.millis())
|
||||||
|
.setStory(isStory)
|
||||||
|
.setEphemeral(isEphemeral)
|
||||||
|
.setUrgent(isUrgent)
|
||||||
|
.setSharedMrmKey(ByteString.copyFrom(sharedMrmKey))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
assertEquals(expectedPresenceByAccountAndDeviceId,
|
||||||
|
messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent).join());
|
||||||
|
|
||||||
|
verify(messagesCache).insert(any(),
|
||||||
|
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
|
||||||
|
eq(Device.PRIMARY_ID),
|
||||||
|
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
|
||||||
|
|
||||||
|
verify(messagesCache).insert(any(),
|
||||||
|
eq(singleDeviceAccountAciServiceIdentifier.uuid()),
|
||||||
|
eq(Device.PRIMARY_ID),
|
||||||
|
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(singleDeviceAccountPniServiceIdentifier.toServiceIdentifierString()).build()));
|
||||||
|
|
||||||
|
verify(messagesCache).insert(any(),
|
||||||
|
eq(multiDeviceAccountAciServiceIdentifier.uuid()),
|
||||||
|
eq((byte) (Device.PRIMARY_ID + 1)),
|
||||||
|
eq(prototypeExpectedMessage.toBuilder().setDestinationServiceId(multiDeviceAccountAciServiceIdentifier.toServiceIdentifierString()).build()));
|
||||||
|
|
||||||
|
verify(messagesCache, never()).insert(any(),
|
||||||
|
eq(unresolvedAccountAciServiceIdentifier.uuid()),
|
||||||
|
anyByte(),
|
||||||
|
any());
|
||||||
|
}
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@CsvSource({
|
@CsvSource({
|
||||||
"false, false, false",
|
"false, false, false",
|
||||||
|
|
|
@ -29,6 +29,7 @@ class ReportMessageDynamoDbTest {
|
||||||
void setUp() {
|
void setUp() {
|
||||||
this.reportMessageDynamoDb = new ReportMessageDynamoDb(
|
this.reportMessageDynamoDb = new ReportMessageDynamoDb(
|
||||||
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
|
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
|
||||||
|
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
|
||||||
Tables.REPORT_MESSAGES.tableName(),
|
Tables.REPORT_MESSAGES.tableName(),
|
||||||
Duration.ofDays(1));
|
Duration.ofDays(1));
|
||||||
}
|
}
|
||||||
|
@ -44,8 +45,8 @@ class ReportMessageDynamoDbTest {
|
||||||
() -> assertFalse(reportMessageDynamoDb.remove(hash2))
|
() -> assertFalse(reportMessageDynamoDb.remove(hash2))
|
||||||
);
|
);
|
||||||
|
|
||||||
reportMessageDynamoDb.store(hash1);
|
reportMessageDynamoDb.store(hash1).join();
|
||||||
reportMessageDynamoDb.store(hash2);
|
reportMessageDynamoDb.store(hash2).join();
|
||||||
|
|
||||||
assertAll("both hashes should be found",
|
assertAll("both hashes should be found",
|
||||||
() -> assertTrue(reportMessageDynamoDb.remove(hash1)),
|
() -> assertTrue(reportMessageDynamoDb.remove(hash1)),
|
||||||
|
|
|
@ -18,6 +18,7 @@ import static org.mockito.Mockito.when;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
|
@ -68,8 +69,8 @@ class ReportMessageManagerTest {
|
||||||
|
|
||||||
verify(reportMessageDynamoDb).store(any());
|
verify(reportMessageDynamoDb).store(any());
|
||||||
|
|
||||||
doThrow(RuntimeException.class)
|
when(reportMessageDynamoDb.store(any()))
|
||||||
.when(reportMessageDynamoDb).store(any());
|
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
|
||||||
|
|
||||||
assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid));
|
assertDoesNotThrow(() -> reportMessageManager.store(sourceAci.toString(), messageGuid));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.tests.util;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MultiRecipientMessageHelper {
|
||||||
|
|
||||||
|
private MultiRecipientMessageHelper() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients) {
|
||||||
|
return generateMultiRecipientMessage(recipients, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static byte[] generateMultiRecipientMessage(final List<TestRecipient> recipients, final int sharedPayloadSize) {
|
||||||
|
if (sharedPayloadSize < 32) {
|
||||||
|
throw new IllegalArgumentException("Shared payload size must be at least 32 bytes");
|
||||||
|
}
|
||||||
|
|
||||||
|
final ByteBuffer buffer = ByteBuffer.allocate(payloadSize(recipients, sharedPayloadSize));
|
||||||
|
|
||||||
|
// first write the header
|
||||||
|
buffer.put((byte) 0x23); // version byte
|
||||||
|
|
||||||
|
// count varint
|
||||||
|
writeVarint(buffer, recipients.size());
|
||||||
|
|
||||||
|
recipients.forEach(recipient -> {
|
||||||
|
buffer.put(recipient.uuid().toFixedWidthByteArray());
|
||||||
|
|
||||||
|
assert recipient.deviceIds().length == recipient.registrationIds().length;
|
||||||
|
|
||||||
|
for (int i = 0; i < recipient.deviceIds().length; i++) {
|
||||||
|
final int hasMore = i == recipient.deviceIds().length - 1 ? 0 : 0x8000;
|
||||||
|
buffer.put(recipient.deviceIds()[i]); // device id (1 byte)
|
||||||
|
buffer.putShort((short) (recipient.registrationIds()[i] | hasMore)); // registration id (2 bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.put(recipient.perRecipientKeyMaterial()); // key material (48 bytes)
|
||||||
|
});
|
||||||
|
|
||||||
|
// now write the actual message body (empty for now)
|
||||||
|
writeVarint(buffer, sharedPayloadSize);
|
||||||
|
buffer.put(new byte[sharedPayloadSize]);
|
||||||
|
|
||||||
|
return buffer.array();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void writeVarint(final ByteBuffer buffer, long n) {
|
||||||
|
if (n < 0) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
|
||||||
|
while (n >= 0x80) {
|
||||||
|
buffer.put ((byte) (n & 0x7F | 0x80));
|
||||||
|
n >>= 7;
|
||||||
|
}
|
||||||
|
buffer.put((byte) (n & 0x7F));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int payloadSize(final List<TestRecipient> recipients, final int sharedPayloadSize) {
|
||||||
|
final int fixedBytesPerRecipient = 17 // Service identifier length
|
||||||
|
+ 48; // Per-recipient key material
|
||||||
|
|
||||||
|
final int bytesForDevices = 3 * recipients.stream()
|
||||||
|
.mapToInt(recipient -> recipient.deviceIds().length)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
return 1 // Version byte
|
||||||
|
+ varintLength(recipients.size())
|
||||||
|
+ (recipients.size() * fixedBytesPerRecipient)
|
||||||
|
+ bytesForDevices
|
||||||
|
+ varintLength(sharedPayloadSize)
|
||||||
|
+ sharedPayloadSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int varintLength(long n) {
|
||||||
|
int length = 0;
|
||||||
|
|
||||||
|
while (n >= 0x80) {
|
||||||
|
length += 1;
|
||||||
|
n >>= 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
return length + 1;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.tests.util;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||||
|
|
||||||
|
public record TestRecipient(ServiceIdentifier uuid,
|
||||||
|
byte[] deviceIds,
|
||||||
|
int[] registrationIds,
|
||||||
|
byte[] perRecipientKeyMaterial) {
|
||||||
|
|
||||||
|
public TestRecipient(ServiceIdentifier uuid,
|
||||||
|
byte deviceId,
|
||||||
|
int registrationId,
|
||||||
|
byte[] perRecipientKeyMaterial) {
|
||||||
|
|
||||||
|
this(uuid, new byte[]{deviceId}, new int[]{registrationId}, perRecipientKeyMaterial);
|
||||||
|
}
|
||||||
|
}
|
|
@ -132,7 +132,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) {
|
void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount) {
|
||||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||||
mock(ReceiptSender.class),
|
mock(ReceiptSender.class),
|
||||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||||
new MessageMetrics(),
|
new MessageMetrics(),
|
||||||
mock(PushNotificationManager.class),
|
mock(PushNotificationManager.class),
|
||||||
mock(PushNotificationScheduler.class),
|
mock(PushNotificationScheduler.class),
|
||||||
|
@ -164,7 +164,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||||
|
|
||||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||||
expectedMessages.add(envelope);
|
expectedMessages.add(envelope);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
void testProcessStoredMessagesClientClosed() {
|
void testProcessStoredMessagesClientClosed() {
|
||||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||||
mock(ReceiptSender.class),
|
mock(ReceiptSender.class),
|
||||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||||
new MessageMetrics(),
|
new MessageMetrics(),
|
||||||
mock(PushNotificationManager.class),
|
mock(PushNotificationManager.class),
|
||||||
mock(PushNotificationScheduler.class),
|
mock(PushNotificationScheduler.class),
|
||||||
|
@ -253,7 +253,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
for (int i = 0; i < cachedMessageCount; i++) {
|
for (int i = 0; i < cachedMessageCount; i++) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||||
|
|
||||||
expectedMessages.add(envelope);
|
expectedMessages.add(envelope);
|
||||||
}
|
}
|
||||||
|
@ -289,7 +289,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
void testProcessStoredMessagesSendFutureTimeout() {
|
void testProcessStoredMessagesSendFutureTimeout() {
|
||||||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||||
mock(ReceiptSender.class),
|
mock(ReceiptSender.class),
|
||||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService, Clock.systemUTC()),
|
||||||
new MessageMetrics(),
|
new MessageMetrics(),
|
||||||
mock(PushNotificationManager.class),
|
mock(PushNotificationManager.class),
|
||||||
mock(PushNotificationScheduler.class),
|
mock(PushNotificationScheduler.class),
|
||||||
|
@ -323,7 +323,7 @@ class WebSocketConnectionIntegrationTest {
|
||||||
for (int i = 0; i < cachedMessageCount; i++) {
|
for (int i = 0; i < cachedMessageCount; i++) {
|
||||||
final UUID messageGuid = UUID.randomUUID();
|
final UUID messageGuid = UUID.randomUUID();
|
||||||
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
|
||||||
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope);
|
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
|
||||||
|
|
||||||
expectedMessages.add(envelope);
|
expectedMessages.add(envelope);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue