diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java index a0a3a45de..2ebca9d2a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java @@ -12,7 +12,6 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import org.signal.chat.messages.IndividualRecipientMessageBundle; -import org.signal.chat.messages.MismatchedDevices; import org.signal.chat.messages.MultiRecipientMismatchedDevices; import org.signal.chat.messages.SendMessageResponse; import org.signal.chat.messages.SendMultiRecipientMessageRequest; @@ -25,7 +24,6 @@ import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidVersionException; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -178,21 +176,11 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me entry -> entry.getKey().byteValue(), entry -> entry.getValue().getRegistrationId())); - try { - messageSender.sendMessages(destination, - destinationServiceIdentifier, - messagesByDeviceId, - registrationIdsByDeviceId, - RequestAttributesUtil.getRawUserAgent().orElse(null)); - - return SEND_MESSAGE_SUCCESS_RESPONSE; - } catch (final MismatchedDevicesException e) { - return SendMessageResponse.newBuilder() - .setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices())) - .build(); - } catch (final MessageTooLargeException e) { - throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException(); - } + return MessagesGrpcHelper.sendMessage(messageSender, + destination, + destinationServiceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId); } @Override @@ -276,7 +264,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me MultiRecipientMismatchedDevices.newBuilder(); e.getMismatchedDevicesByServiceIdentifier().forEach((serviceIdentifier, mismatchedDevices) -> - mismatchedDevicesBuilder.addMismatchedDevices(buildMismatchedDevices(serviceIdentifier, mismatchedDevices))); + mismatchedDevicesBuilder.addMismatchedDevices(MessagesGrpcHelper.buildMismatchedDevices(serviceIdentifier, mismatchedDevices))); return SendMultiRecipientMessageResponse.newBuilder() .setMismatchedDevices(mismatchedDevicesBuilder) @@ -303,17 +291,4 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me return multiRecipientMessage; } - - private MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier, - org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) { - - final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder() - .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)); - - mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices); - mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices); - mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices); - - return mismatchedDevicesBuilder.build(); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java new file mode 100644 index 000000000..f40b870e1 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java @@ -0,0 +1,85 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Status; +import io.grpc.StatusException; +import java.util.Map; +import org.signal.chat.messages.MismatchedDevices; +import org.signal.chat.messages.SendMessageResponse; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; +import org.whispersystems.textsecuregcm.storage.Account; + +public class MessagesGrpcHelper { + + private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build(); + + /** + * Sends a "bundle" of messages to an individual destination account, mapping common exceptions to appropriate gRPC + * statuses. + * + * @param messageSender the {@code MessageSender} instance to use to send the messages + * @param destination the destination account for the messages + * @param destinationServiceIdentifier the service identifier for the destination account + * @param messagesByDeviceId a map of device IDs to message payloads + * @param registrationIdsByDeviceId a map of device IDs to device registration IDs + * + * @return a response object to send to callers + * + * @throws StatusException if the message bundle could not be sent due to an out-of-date device set or an invalid + * message payload + * @throws RateLimitExceededException if the message bundle could not be sent due to a violated rated limit + */ + public static SendMessageResponse sendMessage(final MessageSender messageSender, + final Account destination, + final ServiceIdentifier destinationServiceIdentifier, + final Map messagesByDeviceId, + final Map registrationIdsByDeviceId) throws StatusException, RateLimitExceededException { + + try { + messageSender.sendMessages(destination, + destinationServiceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + RequestAttributesUtil.getRawUserAgent().orElse(null)); + + return SEND_MESSAGE_SUCCESS_RESPONSE; + } catch (final MismatchedDevicesException e) { + return SendMessageResponse.newBuilder() + .setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices())) + .build(); + } catch (final MessageTooLargeException e) { + throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException(); + } + } + + /** + * Translates an internal {@link org.whispersystems.textsecuregcm.controllers.MismatchedDevices} entity to a gRPC + * {@link MismatchedDevices} entity. + * + * @param serviceIdentifier the service identifier to which the mismatched device response applies + * @param mismatchedDevices the mismatched device entity to translate to gRPC + * + * @return a gRPC {@code MismatchedDevices} representation of the given mismatched devices + */ + public static MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier, + final org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) { + + final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)); + + mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices); + mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices); + mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices); + + return mismatchedDevicesBuilder.build(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java new file mode 100644 index 000000000..342fa25a2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java @@ -0,0 +1,187 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.StatusException; +import java.time.Clock; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.signal.chat.messages.AuthenticatedSenderMessageType; +import org.signal.chat.messages.IndividualRecipientMessageBundle; +import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest; +import org.signal.chat.messages.SendMessageResponse; +import org.signal.chat.messages.SendSyncMessageRequest; +import org.signal.chat.messages.SimpleMessagesGrpc; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.spam.GrpcResponse; +import org.whispersystems.textsecuregcm.spam.MessageType; +import org.whispersystems.textsecuregcm.spam.SpamCheckResult; +import org.whispersystems.textsecuregcm.spam.SpamChecker; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; + +public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase { + + private final AccountsManager accountsManager; + private final RateLimiters rateLimiters; + private final MessageSender messageSender; + private final CardinalityEstimator messageByteLimitEstimator; + private final SpamChecker spamChecker; + private final Clock clock; + + public MessagesGrpcService(final AccountsManager accountsManager, + final RateLimiters rateLimiters, + final MessageSender messageSender, + final CardinalityEstimator messageByteLimitEstimator, + final SpamChecker spamChecker, + final Clock clock) { + + this.accountsManager = accountsManager; + this.rateLimiters = rateLimiters; + this.messageSender = messageSender; + this.messageByteLimitEstimator = messageByteLimitEstimator; + this.spamChecker = spamChecker; + this.clock = clock; + } + + @Override + public SendMessageResponse sendMessage(final SendAuthenticatedSenderMessageRequest request) + throws StatusException, RateLimitExceededException { + + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier()); + final Account sender = + accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException); + + final ServiceIdentifier destinationServiceIdentifier = + ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination()); + + if (sender.isIdentifiedBy(destinationServiceIdentifier)) { + throw Status.INVALID_ARGUMENT + .withDescription("Use `sendSyncMessage` to send messages to own account") + .asException(); + } + + final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier) + .orElseThrow(Status.NOT_FOUND::asException); + + rateLimiters.getMessagesLimiter().validate(authenticatedDevice.accountIdentifier(), destination.getUuid()); + + return sendMessage(destination, + destinationServiceIdentifier, + authenticatedDevice, + request.getType(), + MessageType.INDIVIDUAL_IDENTIFIED_SENDER, + request.getMessages(), + request.getEphemeral(), + request.getUrgent()); + } + + @Override + public SendMessageResponse sendSyncMessage(final SendSyncMessageRequest request) + throws StatusException, RateLimitExceededException { + + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + final AciServiceIdentifier senderServiceIdentifier = new AciServiceIdentifier(authenticatedDevice.accountIdentifier()); + final Account sender = + accountsManager.getByServiceIdentifier(senderServiceIdentifier).orElseThrow(Status.UNAUTHENTICATED::asException); + + return sendMessage(sender, + senderServiceIdentifier, + authenticatedDevice, + request.getType(), + MessageType.SYNC, + request.getMessages(), + false, + request.getUrgent()); + } + + private SendMessageResponse sendMessage(final Account destination, + final ServiceIdentifier destinationServiceIdentifier, + final AuthenticatedDevice sender, + final AuthenticatedSenderMessageType envelopeType, + final MessageType messageType, + final IndividualRecipientMessageBundle messages, + final boolean ephemeral, + final boolean urgent) throws StatusException, RateLimitExceededException { + + try { + final int totalPayloadLength = messages.getMessagesMap().values().stream() + .mapToInt(message -> message.getPayload().size()) + .sum(); + + rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength); + } catch (final RateLimitExceededException e) { + messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString()); + throw e; + } + + final SpamCheckResult> spamCheckResult = + spamChecker.checkForIndividualRecipientSpamGrpc(messageType, + Optional.of(sender), + Optional.of(destination), + destinationServiceIdentifier); + + if (spamCheckResult.response().isPresent()) { + return spamCheckResult.response().get().getResponseOrThrowStatus(); + } + + final Map messagesByDeviceId = messages.getMessagesMap().entrySet() + .stream() + .collect(Collectors.toMap( + entry -> DeviceIdUtil.validate(entry.getKey()), + entry -> { + final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() + .setType(getEnvelopeType(envelopeType)) + .setClientTimestamp(messages.getTimestamp()) + .setServerTimestamp(clock.millis()) + .setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString()) + .setSourceServiceId(new AciServiceIdentifier(sender.accountIdentifier()).toServiceIdentifierString()) + .setSourceDevice(sender.deviceId()) + .setEphemeral(ephemeral) + .setUrgent(urgent) + .setContent(entry.getValue().getPayload()); + + spamCheckResult.token().ifPresent(reportSpamToken -> + envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken))); + + return envelopeBuilder.build(); + } + )); + + final Map registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream() + .collect(Collectors.toMap( + entry -> entry.getKey().byteValue(), + entry -> entry.getValue().getRegistrationId())); + + return MessagesGrpcHelper.sendMessage(messageSender, + destination, + destinationServiceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId); + } + + private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) { + return switch (type) { + case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT; + case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE; + case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT; + case UNSPECIFIED, UNRECOGNIZED -> + throw Status.INVALID_ARGUMENT.withDescription("Unrecognized envelope type").asRuntimeException(); + }; + } +} diff --git a/service/src/main/proto/org/signal/chat/messages.proto b/service/src/main/proto/org/signal/chat/messages.proto index 69baeb8f8..7faf84683 100644 --- a/service/src/main/proto/org/signal/chat/messages.proto +++ b/service/src/main/proto/org/signal/chat/messages.proto @@ -24,10 +24,12 @@ service Messages { * destination account. * * This RPC may fail with a `NOT_FOUND` status if the destination account was - * not found. It may also fail with a `RESOURCE_EXHAUSTED` status if a rate - * limit for sending messages has been exceeded, in which case a `retry-after` - * header containing an ISO 8601 duration string may be present in the - * response trailers. + * not found. It may also fail with an `INVALID_ARGUMENT` status if the + * destination account is the same as the authenticated caller (callers should + * use `SendSyncMessage` to send messages to themselves). It may also fail + * with a `RESOURCE_EXHAUSTED` status if a rate limit for sending messages has + * been exceeded, in which case a `retry-after` header containing an ISO 8601 + * duration string may be present in the response trailers. * * Note that message delivery may not succeed even if this RPC returns an `OK` * status; callers must check the response object to verify that the message diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java new file mode 100644 index 000000000..ec21a19d3 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java @@ -0,0 +1,616 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.mockito.Mock; +import org.signal.chat.messages.AuthenticatedSenderMessageType; +import org.signal.chat.messages.ChallengeRequired; +import org.signal.chat.messages.IndividualRecipientMessageBundle; +import org.signal.chat.messages.MessagesGrpc; +import org.signal.chat.messages.MismatchedDevices; +import org.signal.chat.messages.SendAuthenticatedSenderMessageRequest; +import org.signal.chat.messages.SendMessageResponse; +import org.signal.chat.messages.SendSyncMessageRequest; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; +import org.whispersystems.textsecuregcm.spam.GrpcResponse; +import org.whispersystems.textsecuregcm.spam.MessageType; +import org.whispersystems.textsecuregcm.spam.SpamCheckResult; +import org.whispersystems.textsecuregcm.spam.SpamChecker; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; +import org.whispersystems.textsecuregcm.util.TestClock; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; + +class MessagesGrpcServiceTest extends SimpleBaseGrpcTest { + + @Mock + private AccountsManager accountsManager; + + @Mock + private RateLimiters rateLimiters; + + @Mock + private MessageSender messageSender; + + @Mock + private CardinalityEstimator messageByteLimitEstimator; + + @Mock + private SpamChecker spamChecker; + + @Mock + private RateLimiter rateLimiter; + + @Mock + private Account authenticatedAccount; + + @Mock + private Device authenticatedDevice; + + @Mock + private Device linkedDevice; + + @Mock + private Device secondLinkedDevice; + + private static final int AUTHENTICATED_REGISTRATION_ID = 7; + + private static final byte LINKED_DEVICE_ID = AUTHENTICATED_DEVICE_ID + 1; + private static final int LINKED_DEVICE_REGISTRATION_ID = 13; + + private static final byte SECOND_LINKED_DEVICE_ID = LINKED_DEVICE_ID + 1; + private static final int SECOND_LINKED_DEVICE_REGISTRATION_ID = 19; + + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + + @Override + protected MessagesGrpcService createServiceBeforeEachTest() { + return new MessagesGrpcService(accountsManager, + rateLimiters, + messageSender, + messageByteLimitEstimator, + spamChecker, + CLOCK); + } + + @BeforeEach + void setUp() { + when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty()); + + when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); + when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty())); + + when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID); + when(authenticatedDevice.getRegistrationId()).thenReturn(AUTHENTICATED_REGISTRATION_ID); + + when(linkedDevice.getId()).thenReturn(LINKED_DEVICE_ID); + when(linkedDevice.getRegistrationId()).thenReturn(LINKED_DEVICE_REGISTRATION_ID); + + when(secondLinkedDevice.getId()).thenReturn(SECOND_LINKED_DEVICE_ID); + when(secondLinkedDevice.getRegistrationId()).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID); + + when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI); + when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI); + when(authenticatedAccount.getDevice(anyByte())).thenReturn(Optional.empty()); + when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice)); + when(authenticatedAccount.getDevice(LINKED_DEVICE_ID)).thenReturn(Optional.of(linkedDevice)); + when(authenticatedAccount.getDevice(SECOND_LINKED_DEVICE_ID)).thenReturn(Optional.of(secondLinkedDevice)); + when(authenticatedAccount.getDevices()).thenReturn(List.of(authenticatedDevice, linkedDevice, secondLinkedDevice)); + + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI))) + .thenReturn(Optional.of(authenticatedAccount)); + } + + @Nested + class SingleRecipient { + + @CartesianTest + void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType, + @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent, + @CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken) + throws MessageTooLargeException, MismatchedDevicesException { + + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final byte[] reportSpamToken = TestRandomUtil.nextBytes(64); + + if (includeReportSpamToken) { + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken))); + } + + final byte[] payload = TestRandomUtil.nextBytes(128); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(payload)) + .build()); + + final SendMessageResponse response = authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, messageType, ephemeral, urgent, messages)); + + assertEquals(SendMessageResponse.newBuilder().build(), response); + + final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) { + case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT; + case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE; + case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT; + case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType); + }; + + final MessageProtos.Envelope.Builder expectedEnvelopeBuilder = MessageProtos.Envelope.newBuilder() + .setType(expectedEnvelopeType) + .setSourceServiceId(new AciServiceIdentifier(AUTHENTICATED_ACI).toServiceIdentifierString()) + .setSourceDevice(AUTHENTICATED_DEVICE_ID) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setEphemeral(ephemeral) + .setUrgent(urgent) + .setContent(ByteString.copyFrom(payload)); + + if (includeReportSpamToken) { + expectedEnvelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)); + } + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER, + Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender).sendMessages(destinationAccount, + serviceIdentifier, + Map.of(deviceId, expectedEnvelopeBuilder.build()), + Map.of(deviceId, registrationId), + null); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + final SendMessageResponse response = authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages)); + + final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder() + .setMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .addMissingDevices(missingDeviceId) + .addStaleDevices(staleDeviceId) + .addExtraDevices(extraDeviceId) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void destinationNotFound() throws MessageTooLargeException, MismatchedDevicesException { + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + final Map messages = + Map.of(Device.PRIMARY_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(1234) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.NOT_FOUND, + () -> authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Duration retryDuration = Duration.ofHours(7); + + doThrow(new RateLimitExceededException(retryDuration)) + .when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyInt()); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertRateLimitExceeded(retryDuration, + () -> authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, + () -> authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); + } + + @Test + void spamWithStatus() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED, + () -> authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER, + Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void spamWithResponse() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + final SendMessageResponse response = SendMessageResponse.newBuilder() + .setChallengeRequired(ChallengeRequired.newBuilder() + .addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA)) + .build(); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty())); + + assertEquals(response, authenticatedServiceStub().sendMessage( + generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_IDENTIFIED_SENDER, + Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + private static SendAuthenticatedSenderMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier, + final AuthenticatedSenderMessageType messageType, + final boolean ephemeral, + final boolean urgent, + final Map messages) { + + final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder() + .setTimestamp(CLOCK.millis()); + + messages.forEach(messageBundleBuilder::putMessages); + + final SendAuthenticatedSenderMessageRequest.Builder requestBuilder = SendAuthenticatedSenderMessageRequest.newBuilder() + .setDestination(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .setType(messageType) + .setMessages(messageBundleBuilder) + .setEphemeral(ephemeral) + .setUrgent(urgent); + + return requestBuilder.build(); + } + } + + @Nested + class Sync { + + @CartesianTest + void sendMessage(@CartesianTest.Enum(mode = CartesianTest.Enum.Mode.EXCLUDE, names = {"UNSPECIFIED", "UNRECOGNIZED"}) final AuthenticatedSenderMessageType messageType, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent, + @CartesianTest.Values(booleans = {true, false}) final boolean includeReportSpamToken) + throws MessageTooLargeException, MismatchedDevicesException { + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(AUTHENTICATED_ACI); + final byte[] payload = TestRandomUtil.nextBytes(128); + + final Map messages = + Map.of(LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(LINKED_DEVICE_REGISTRATION_ID) + .setPayload(ByteString.copyFrom(payload)) + .build(), + + SECOND_LINKED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(SECOND_LINKED_DEVICE_REGISTRATION_ID) + .setPayload(ByteString.copyFrom(payload)) + .build()); + + final byte[] reportSpamToken = TestRandomUtil.nextBytes(64); + + if (includeReportSpamToken) { + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.of(reportSpamToken))); + } + + final SendMessageResponse response = + authenticatedServiceStub().sendSyncMessage(generateRequest(messageType, urgent, messages)); + + assertEquals(SendMessageResponse.newBuilder().build(), response); + + final MessageProtos.Envelope.Type expectedEnvelopeType = switch (messageType) { + case DOUBLE_RATCHET -> MessageProtos.Envelope.Type.CIPHERTEXT; + case PREKEY_MESSAGE -> MessageProtos.Envelope.Type.PREKEY_BUNDLE; + case PLAINTEXT_CONTENT -> MessageProtos.Envelope.Type.PLAINTEXT_CONTENT; + case UNSPECIFIED, UNRECOGNIZED -> throw new IllegalArgumentException("Unexpected message type: " + messageType); + }; + + final Map expectedEnvelopes = new HashMap<>(Map.of( + LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder() + .setType(expectedEnvelopeType) + .setSourceServiceId(serviceIdentifier.toServiceIdentifierString()) + .setSourceDevice(AUTHENTICATED_DEVICE_ID) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setEphemeral(false) + .setUrgent(urgent) + .setContent(ByteString.copyFrom(payload)) + .build(), + + SECOND_LINKED_DEVICE_ID, MessageProtos.Envelope.newBuilder() + .setType(expectedEnvelopeType) + .setSourceServiceId(serviceIdentifier.toServiceIdentifierString()) + .setSourceDevice(AUTHENTICATED_DEVICE_ID) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setEphemeral(false) + .setUrgent(urgent) + .setContent(ByteString.copyFrom(payload)) + .build() + )); + + if (includeReportSpamToken) { + expectedEnvelopes.replaceAll((deviceId, envelope) -> + envelope.toBuilder().setReportSpamToken(ByteString.copyFrom(reportSpamToken)).build()); + } + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.SYNC, + Optional.of(new AuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)), + Optional.of(authenticatedAccount), + serviceIdentifier); + + verify(messageSender).sendMessages(authenticatedAccount, + serviceIdentifier, + expectedEnvelopes, + Map.of(LINKED_DEVICE_ID, LINKED_DEVICE_REGISTRATION_ID, + SECOND_LINKED_DEVICE_ID, SECOND_LINKED_DEVICE_REGISTRATION_ID), + null); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + final SendMessageResponse response = authenticatedServiceStub().sendSyncMessage( + generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages)); + + final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder() + .setMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(new AciServiceIdentifier(AUTHENTICATED_ACI))) + .addMissingDevices(missingDeviceId) + .addStaleDevices(staleDeviceId) + .addExtraDevices(extraDeviceId) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException { + final Duration retryDuration = Duration.ofHours(7); + doThrow(new RateLimitExceededException(retryDuration)) + .when(rateLimiter).validate(eq(AUTHENTICATED_ACI), anyInt()); + + final Map messages = + Map.of(AUTHENTICATED_DEVICE_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(AUTHENTICATED_REGISTRATION_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertRateLimitExceeded(retryDuration, + () -> authenticatedServiceStub().sendSyncMessage( + generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageByteLimitEstimator).add(AUTHENTICATED_ACI.toString()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, + () -> authenticatedServiceStub().sendSyncMessage( + generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages))); + } + + private static SendSyncMessageRequest generateRequest(final AuthenticatedSenderMessageType messageType, + final boolean urgent, + final Map messages) { + + final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder() + .setTimestamp(CLOCK.millis()); + + messages.forEach(messageBundleBuilder::putMessages); + + final SendSyncMessageRequest.Builder requestBuilder = SendSyncMessageRequest.newBuilder() + .setType(messageType) + .setMessages(messageBundleBuilder) + .setUrgent(urgent); + + return requestBuilder.build(); + } + } +}