diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GroupSendTokenUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GroupSendTokenUtil.java index 60380a932..3ae56c94d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GroupSendTokenUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/GroupSendTokenUtil.java @@ -9,6 +9,7 @@ import com.google.protobuf.ByteString; import io.grpc.Status; import io.grpc.StatusException; import java.time.Clock; +import java.util.Collection; import java.util.List; import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.zkgroup.InvalidInputException; @@ -29,11 +30,16 @@ public class GroupSendTokenUtil { } public void checkGroupSendToken(final ByteString serializedGroupSendToken, - final List serviceIdentifiers) throws StatusException { + final ServiceIdentifier serviceIdentifier) throws StatusException { + + checkGroupSendToken(serializedGroupSendToken, List.of(serviceIdentifier.toLibsignal())); + } + + public void checkGroupSendToken(final ByteString serializedGroupSendToken, + final Collection serviceIds) throws StatusException { try { final GroupSendFullToken token = new GroupSendFullToken(serializedGroupSendToken.toByteArray()); - final List serviceIds = serviceIdentifiers.stream().map(ServiceIdentifier::toLibsignal).toList(); token.verify(serviceIds, clock.instant(), GroupSendDerivedKeyPair.forExpiration(token.getExpiration(), serverSecretParams)); } catch (final InvalidInputException e) { throw Status.INVALID_ARGUMENT.asException(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java index 2bbb19239..9e35b55c8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java @@ -7,13 +7,11 @@ package org.whispersystems.textsecuregcm.grpc; import com.google.protobuf.ByteString; import io.grpc.Status; +import io.grpc.StatusException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.time.Clock; import java.util.Arrays; -import java.util.List; - -import io.grpc.StatusException; import org.signal.chat.keys.CheckIdentityKeyRequest; import org.signal.chat.keys.CheckIdentityKeyResponse; import org.signal.chat.keys.GetPreKeysAnonymousRequest; @@ -55,7 +53,7 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony return switch (request.getAuthorizationCase()) { case GROUP_SEND_TOKEN -> { try { - groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(serviceIdentifier)); + groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), serviceIdentifier); yield lookUpAccount(serviceIdentifier, Status.NOT_FOUND) .flatMap(targetAccount -> KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java new file mode 100644 index 000000000..e26bf33ee --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java @@ -0,0 +1,282 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.StatusException; +import java.time.Clock; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import org.signal.chat.messages.MismatchedDevices; +import org.signal.chat.messages.MultiRecipientMismatchedDevices; +import org.signal.chat.messages.SendMessageResponse; +import org.signal.chat.messages.SendMultiRecipientMessageRequest; +import org.signal.chat.messages.SendMultiRecipientMessageResponse; +import org.signal.chat.messages.SendSealedSenderMessageRequest; +import org.signal.chat.messages.SimpleMessagesAnonymousGrpc; +import org.signal.libsignal.protocol.InvalidMessageException; +import org.signal.libsignal.protocol.InvalidVersionException; +import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; +import org.whispersystems.textsecuregcm.spam.GrpcResponse; +import org.whispersystems.textsecuregcm.spam.MessageType; +import org.whispersystems.textsecuregcm.spam.SpamCheckResult; +import org.whispersystems.textsecuregcm.spam.SpamChecker; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.MessagesAnonymousImplBase { + + private final AccountsManager accountsManager; + private final RateLimiters rateLimiters; + private final MessageSender messageSender; + private final GroupSendTokenUtil groupSendTokenUtil; + private final CardinalityEstimator messageByteLimitEstimator; + private final SpamChecker spamChecker; + private final Clock clock; + + private static final SendMessageResponse SEND_MESSAGE_SUCCESS_RESPONSE = SendMessageResponse.newBuilder().build(); + + private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8; + + public MessagesAnonymousGrpcService(final AccountsManager accountsManager, + final RateLimiters rateLimiters, + final MessageSender messageSender, + final GroupSendTokenUtil groupSendTokenUtil, + final CardinalityEstimator messageByteLimitEstimator, + final SpamChecker spamChecker, + final Clock clock) { + + this.accountsManager = accountsManager; + this.rateLimiters = rateLimiters; + this.messageSender = messageSender; + this.messageByteLimitEstimator = messageByteLimitEstimator; + this.spamChecker = spamChecker; + this.clock = clock; + this.groupSendTokenUtil = groupSendTokenUtil; + } + + @Override + public SendMessageResponse sendSingleRecipientMessage(final SendSealedSenderMessageRequest request) + throws StatusException, RateLimitExceededException { + + final ServiceIdentifier destinationServiceIdentifier = + ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination()); + + final Account destination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier) + .orElseThrow(Status.UNAUTHENTICATED::asException); + + switch (request.getAuthorizationCase()) { + case UNIDENTIFIED_ACCESS_KEY -> { + if (!UnidentifiedAccessUtil.checkUnidentifiedAccess(destination, request.getUnidentifiedAccessKey().toByteArray())) { + throw Status.UNAUTHENTICATED.asException(); + } + } + case GROUP_SEND_TOKEN -> + groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), destinationServiceIdentifier); + + case AUTHORIZATION_NOT_SET -> throw Status.UNAUTHENTICATED.asException(); + } + + final SpamCheckResult> spamCheckResult = + spamChecker.checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_SEALED_SENDER, + Optional.empty(), + Optional.of(destination), + destinationServiceIdentifier); + + if (spamCheckResult.response().isPresent()) { + final GrpcResponse response = spamCheckResult.response().get(); + + if (response.response().isPresent()) { + return response.response().get(); + } + + throw response.status().asException(); + } + + try { + final int totalPayloadLength = request.getMessages().getMessagesMap().values().stream() + .mapToInt(message -> message.getPayload().size()) + .sum(); + + rateLimiters.getInboundMessageBytes().validate(destinationServiceIdentifier.uuid(), totalPayloadLength); + } catch (final RateLimitExceededException e) { + messageByteLimitEstimator.add(destinationServiceIdentifier.uuid().toString()); + throw e; + } + + final Map messagesByDeviceId = request.getMessages().getMessagesMap().entrySet() + .stream() + .collect(Collectors.toMap( + entry -> DeviceIdUtil.validate(entry.getKey()), + entry -> { + final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(request.getMessages().getTimestamp()) + .setServerTimestamp(clock.millis()) + .setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString()) + .setEphemeral(request.getEphemeral()) + .setUrgent(request.getUrgent()) + .setContent(entry.getValue().getPayload()); + + spamCheckResult.token() + .ifPresent(token -> envelopeBuilder.setReportSpamToken(ByteString.copyFrom(token))); + + return envelopeBuilder.build(); + } + )); + + final Map registrationIdsByDeviceId = request.getMessages().getMessagesMap().entrySet().stream() + .collect(Collectors.toMap( + entry -> entry.getKey().byteValue(), + entry -> entry.getValue().getRegistrationId())); + + try { + messageSender.sendMessages(destination, + destinationServiceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + RequestAttributesUtil.getRawUserAgent().orElse(null)); + + return SEND_MESSAGE_SUCCESS_RESPONSE; + } catch (final MismatchedDevicesException e) { + return SendMessageResponse.newBuilder() + .setMismatchedDevices(buildMismatchedDevices(destinationServiceIdentifier, e.getMismatchedDevices())) + .build(); + } catch (final MessageTooLargeException e) { + throw Status.INVALID_ARGUMENT.withDescription("Message too large").withCause(e).asException(); + } + } + + @Override + public SendMultiRecipientMessageResponse sendMultiRecipientMessage(final SendMultiRecipientMessageRequest request) + throws StatusException { + + final SealedSenderMultiRecipientMessage multiRecipientMessage; + + try { + multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(request.getMessage().getPayload().toByteArray()); + } catch (final InvalidMessageException | InvalidVersionException e) { + throw Status.INVALID_ARGUMENT.withCause(e).asException(); + } + + // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same + // recipient + { + final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID]; + + for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { + Arrays.fill(usedDeviceIds, false); + + for (final byte deviceId : recipient.getDevices()) { + if (usedDeviceIds[deviceId]) { + throw Status.INVALID_ARGUMENT.withDescription("Request contains repeated device entries").asException(); + } + + usedDeviceIds[deviceId] = true; + } + } + } + + groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), multiRecipientMessage.getRecipients().keySet()); + + final SpamCheckResult> spamCheckResult = + spamChecker.checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_SEALED_SENDER); + + if (spamCheckResult.response().isPresent()) { + final GrpcResponse response = spamCheckResult.response().get(); + + if (response.response().isPresent()) { + return response.response().get(); + } + + throw response.status().asException(); + } + + // At this point, the caller has at least superficially provided the information needed to send a multi-recipient + // message. Attempt to resolve the destination service identifiers to Signal accounts. + final Map resolvedRecipients = + Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) + .flatMap(serviceIdAndRecipient -> { + final ServiceIdentifier serviceIdentifier = + ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey()); + + return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .flatMap(Mono::justOrEmpty) + .map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account)); + }, MAX_FETCH_ACCOUNT_CONCURRENCY) + .collectMap(Tuple2::getT1, Tuple2::getT2) + .blockOptional() + .orElse(Collections.emptyMap()); + + try { + messageSender.sendMultiRecipientMessage(multiRecipientMessage, + resolvedRecipients, + request.getMessage().getTimestamp(), + false, + request.getEphemeral(), + request.getUrgent(), + RequestAttributesUtil.getRawUserAgent().orElse(null)); + + final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder(); + + multiRecipientMessage.getRecipients().entrySet() + .stream() + .filter(entry -> !resolvedRecipients.containsKey(entry.getValue())) + .map(entry -> ServiceIdentifier.fromLibsignal(entry.getKey())) + .map(ServiceIdentifierUtil::toGrpcServiceIdentifier) + .forEach(responseBuilder::addUnresolvedRecipients); + + return responseBuilder.build(); + } catch (final MessageTooLargeException e) { + throw Status.INVALID_ARGUMENT + .withDescription("Message for an individual recipient was too large") + .withCause(e) + .asRuntimeException(); + } catch (final MultiRecipientMismatchedDevicesException e) { + final MultiRecipientMismatchedDevices.Builder mismatchedDevicesBuilder = + MultiRecipientMismatchedDevices.newBuilder(); + + e.getMismatchedDevicesByServiceIdentifier().forEach((serviceIdentifier, mismatchedDevices) -> + mismatchedDevicesBuilder.addMismatchedDevices(buildMismatchedDevices(serviceIdentifier, mismatchedDevices))); + + return SendMultiRecipientMessageResponse.newBuilder() + .setMismatchedDevices(mismatchedDevicesBuilder) + .build(); + } + } + + private MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier, + org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) { + + final MismatchedDevices.Builder mismatchedDevicesBuilder = MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)); + + mismatchedDevices.missingDeviceIds().forEach(mismatchedDevicesBuilder::addMissingDevices); + mismatchedDevices.extraDeviceIds().forEach(mismatchedDevicesBuilder::addExtraDevices); + mismatchedDevices.staleDeviceIds().forEach(mismatchedDevicesBuilder::addStaleDevices); + + return mismatchedDevicesBuilder.build(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcService.java index 1814173d1..46696ab58 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/ProfileAnonymousGrpcService.java @@ -6,11 +6,8 @@ package org.whispersystems.textsecuregcm.grpc; import io.grpc.Status; - -import java.time.Clock; -import java.util.List; - import io.grpc.StatusException; +import java.time.Clock; import org.signal.chat.profile.CredentialType; import org.signal.chat.profile.GetExpiringProfileKeyCredentialAnonymousRequest; import org.signal.chat.profile.GetExpiringProfileKeyCredentialResponse; @@ -62,7 +59,7 @@ public class ProfileAnonymousGrpcService extends ReactorProfileAnonymousGrpc.Pro final Mono account = switch (request.getAuthenticationCase()) { case GROUP_SEND_TOKEN -> { try { - groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), List.of(targetIdentifier)); + groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), targetIdentifier); yield Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier)) .flatMap(Mono::justOrEmpty) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java new file mode 100644 index 000000000..4f03d9610 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java @@ -0,0 +1,792 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.StatusException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.mockito.Mock; +import org.signal.chat.messages.ChallengeRequired; +import org.signal.chat.messages.IndividualRecipientMessageBundle; +import org.signal.chat.messages.MessagesAnonymousGrpc; +import org.signal.chat.messages.MismatchedDevices; +import org.signal.chat.messages.MultiRecipientMessage; +import org.signal.chat.messages.MultiRecipientMismatchedDevices; +import org.signal.chat.messages.SendMessageResponse; +import org.signal.chat.messages.SendMultiRecipientMessageRequest; +import org.signal.chat.messages.SendMultiRecipientMessageResponse; +import org.signal.chat.messages.SendSealedSenderMessageRequest; +import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.MessageTooLargeException; +import org.whispersystems.textsecuregcm.spam.GrpcResponse; +import org.whispersystems.textsecuregcm.spam.MessageType; +import org.whispersystems.textsecuregcm.spam.SpamCheckResult; +import org.whispersystems.textsecuregcm.spam.SpamChecker; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; +import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper; +import org.whispersystems.textsecuregcm.tests.util.TestRecipient; +import org.whispersystems.textsecuregcm.util.TestClock; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; + +class MessagesAnonymousGrpcServiceTest extends + SimpleBaseGrpcTest { + + @Mock + private AccountsManager accountsManager; + + @Mock + private RateLimiters rateLimiters; + + @Mock + private MessageSender messageSender; + + @Mock + private GroupSendTokenUtil groupSendTokenUtil; + + @Mock + private CardinalityEstimator messageByteLimitEstimator; + + @Mock + private SpamChecker spamChecker; + + @Mock + private RateLimiter rateLimiter; + + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + + private static final byte[] UNIDENTIFIED_ACCESS_KEY = + TestRandomUtil.nextBytes(UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH); + + private static final byte[] GROUP_SEND_TOKEN = TestRandomUtil.nextBytes(64); + + @Override + protected MessagesAnonymousGrpcService createServiceBeforeEachTest() { + return new MessagesAnonymousGrpcService(accountsManager, + rateLimiters, + messageSender, + groupSendTokenUtil, + messageByteLimitEstimator, + spamChecker, + CLOCK); + } + + @BeforeEach + void setUp() throws StatusException { + when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); + + doThrow(Status.UNAUTHENTICATED.asException()).when(groupSendTokenUtil) + .checkGroupSendToken(any(), any(ServiceIdentifier.class)); + + doThrow(Status.UNAUTHENTICATED.asException()).when(groupSendTokenUtil) + .checkGroupSendToken(any(), anyCollection()); + + doAnswer(invocation -> null).when(groupSendTokenUtil) + .checkGroupSendToken(eq(ByteString.copyFrom(GROUP_SEND_TOKEN)), any(ServiceIdentifier.class)); + + doAnswer(invocation -> null).when(groupSendTokenUtil) + .checkGroupSendToken(eq(ByteString.copyFrom(GROUP_SEND_TOKEN)), anyCollection()); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty())); + + when(spamChecker.checkForMultiRecipientSpamGrpc(any())) + .thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty())); + } + + @Nested + class SingleRecipient { + + @CartesianTest + void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean useUak, + @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent) + throws MessageTooLargeException, MismatchedDevicesException { + + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final byte[] payload = TestRandomUtil.nextBytes(128); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(payload)) + .build()); + + final SendMessageResponse response = unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, ephemeral, urgent, messages, + useUak ? UNIDENTIFIED_ACCESS_KEY : null, + useUak ? null : GROUP_SEND_TOKEN)); + + assertEquals(SendMessageResponse.newBuilder().build(), response); + + final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setEphemeral(ephemeral) + .setUrgent(urgent) + .setContent(ByteString.copyFrom(payload)) + .build(); + + verify(messageSender).sendMessages(destinationAccount, + serviceIdentifier, + Map.of(deviceId, expectedEnvelope), + Map.of(deviceId, registrationId), + null); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + final SendMessageResponse response = unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null)); + + final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder() + .setMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .addMissingDevices(missingDeviceId) + .addStaleDevices(staleDeviceId) + .addExtraDevices(extraDeviceId) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void badCredentials(final boolean useUak) throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + final byte[] incorrectUnidentifiedAccessKey = UNIDENTIFIED_ACCESS_KEY.clone(); + incorrectUnidentifiedAccessKey[0] += 1; + + final byte[] incorrectGroupSendToken = GROUP_SEND_TOKEN.clone(); + incorrectGroupSendToken[0] += 1; + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, + () -> unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, + useUak ? incorrectUnidentifiedAccessKey : null, + useUak ? null : incorrectGroupSendToken))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void destinationNotFound() throws MessageTooLargeException, MismatchedDevicesException { + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + final Map messages = + Map.of(Device.PRIMARY_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(1234) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, + () -> unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Duration retryDuration = Duration.ofHours(7); + + doThrow(new RateLimitExceededException(retryDuration)).when(rateLimiter).validate(eq(serviceIdentifier.uuid()), anyInt()); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertRateLimitExceeded(retryDuration, + () -> unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, + () -> unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); + } + + @Test + void spamWithStatus() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED, + () -> unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_SEALED_SENDER, + Optional.empty(), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void spamWithResponse() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + final SendMessageResponse response = SendMessageResponse.newBuilder() + .setChallengeRequired(ChallengeRequired.newBuilder() + .addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA)) + .build(); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty())); + + assertEquals(response, unauthenticatedServiceStub().sendSingleRecipientMessage( + generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_SEALED_SENDER, + Optional.empty(), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + private static SendSealedSenderMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier, + final boolean ephemeral, + final boolean urgent, + final Map messages, + @Nullable final byte[] unidentifiedAccessKey, + @Nullable final byte[] groupSendToken) { + + final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder() + .setTimestamp(CLOCK.millis()); + + messages.forEach(messageBundleBuilder::putMessages); + + final SendSealedSenderMessageRequest.Builder requestBuilder = SendSealedSenderMessageRequest.newBuilder() + .setDestination(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .setMessages(messageBundleBuilder) + .setEphemeral(ephemeral) + .setUrgent(urgent); + + if (unidentifiedAccessKey != null) { + requestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)); + } + + if (groupSendToken != null) { + requestBuilder.setGroupSendToken(ByteString.copyFrom(groupSendToken)); + } + + return requestBuilder.build(); + } + } + + @Nested + class MultiRecipient { + + @CartesianTest + void sendMessage(@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, + @CartesianTest.Values(booleans = {true, false}) final boolean urgent) + throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account resolvedAccount = mock(Account.class); + when(resolvedAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(resolvedAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier resolvedServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final AciServiceIdentifier unresolvedServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(resolvedServiceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(resolvedAccount))); + + final TestRecipient resolvedRecipient = + new TestRecipient(resolvedServiceIdentifier, deviceId, registrationId, new byte[48]); + + final TestRecipient unresolvedRecipient = + new TestRecipient(unresolvedServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + resolvedRecipient, unresolvedRecipient)); + + final SendMultiRecipientMessageRequest request = SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(ephemeral) + .setUrgent(urgent) + .build(); + + final SendMultiRecipientMessageResponse response = + unauthenticatedServiceStub().sendMultiRecipientMessage(request); + + final SendMultiRecipientMessageResponse expectedResponse = SendMultiRecipientMessageResponse.newBuilder() + .addUnresolvedRecipients(ServiceIdentifierUtil.toGrpcServiceIdentifier(unresolvedServiceIdentifier)) + .build(); + + assertEquals(expectedResponse, response); + + verify(messageSender).sendMultiRecipientMessage(any(), + argThat(resolvedRecipients -> resolvedRecipients.containsValue(resolvedAccount)), + eq(CLOCK.millis()), + eq(false), + eq(ephemeral), + eq(urgent), + isNull()); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(serviceIdentifier, staleDeviceId, 17, new byte[48]))); + + final SendMultiRecipientMessageRequest request = SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build(); + + doThrow(new MultiRecipientMismatchedDevicesException(Map.of(serviceIdentifier, + new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + final SendMultiRecipientMessageResponse response = + unauthenticatedServiceStub().sendMultiRecipientMessage(request); + + final SendMultiRecipientMessageResponse expectedResponse = SendMultiRecipientMessageResponse.newBuilder() + .setMismatchedDevices(MultiRecipientMismatchedDevices.newBuilder() + .addMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .addMissingDevices(missingDeviceId) + .addExtraDevices(extraDeviceId) + .addStaleDevices(staleDeviceId) + .build()) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void badCredentials() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]); + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient)); + + final byte[] incorrectGroupSendToken = GROUP_SEND_TOKEN.clone(); + incorrectGroupSendToken[0] += 1; + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage(SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(incorrectGroupSendToken)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.UNAUTHENTICATED, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage(SendMultiRecipientMessageRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build())); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void badPayload() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage(SendMultiRecipientMessageRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()) + .build())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage(SendMultiRecipientMessageRequest.newBuilder().build())); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void repeatedRecipient() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final Device destinationDevice = DevicesHelper.createDevice(Device.PRIMARY_ID, CLOCK.millis(), 1); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = new TestRecipient(serviceIdentifier, Device.PRIMARY_ID, 1, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient, recipient)); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage(SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build())); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(serviceIdentifier, Device.PRIMARY_ID, 17, new byte[48]))); + + final SendMultiRecipientMessageRequest request = SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build(); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, + () -> unauthenticatedServiceStub().sendMultiRecipientMessage(request)); + } + + @Test + void spamWithStatus() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = + new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient)); + + final SendMultiRecipientMessageRequest request = SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build(); + + when(spamChecker.checkForMultiRecipientSpamGrpc(any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED, + () -> unauthenticatedServiceStub().sendMultiRecipientMessage(request)); + + verify(spamChecker).checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_SEALED_SENDER); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void spamWithResponse() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = + new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient)); + + final SendMultiRecipientMessageRequest request = SendMultiRecipientMessageRequest.newBuilder() + .setGroupSendToken(ByteString.copyFrom(GROUP_SEND_TOKEN)) + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setEphemeral(false) + .setUrgent(true) + .build(); + + final SendMultiRecipientMessageResponse response = SendMultiRecipientMessageResponse.newBuilder() + .setChallengeRequired(ChallengeRequired.newBuilder() + .addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA)) + .build(); + + when(spamChecker.checkForMultiRecipientSpamGrpc(any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty())); + + assertEquals(response, unauthenticatedServiceStub().sendMultiRecipientMessage(request)); + + verify(spamChecker).checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_SEALED_SENDER); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + } +}