From caa81b4885b97f2e6e806dc26b26048aec5b8c46 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 7 Apr 2025 17:37:34 -0400 Subject: [PATCH] Implement story sending via gRPC --- .../grpc/MessagesAnonymousGrpcService.java | 220 +++++--- .../MessagesAnonymousGrpcServiceTest.java | 519 +++++++++++++++++- 2 files changed, 675 insertions(+), 64 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java index cf7a26a90..c095d7969 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java @@ -5,7 +5,6 @@ package org.whispersystems.textsecuregcm.grpc; -import com.google.protobuf.ByteString; import io.grpc.Status; import io.grpc.StatusException; import java.time.Clock; @@ -14,12 +13,15 @@ import java.util.Collections; import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import org.signal.chat.messages.IndividualRecipientMessageBundle; import org.signal.chat.messages.MismatchedDevices; import org.signal.chat.messages.MultiRecipientMismatchedDevices; import org.signal.chat.messages.SendMessageResponse; import org.signal.chat.messages.SendMultiRecipientMessageRequest; import org.signal.chat.messages.SendMultiRecipientMessageResponse; +import org.signal.chat.messages.SendMultiRecipientStoryRequest; import org.signal.chat.messages.SendSealedSenderMessageRequest; +import org.signal.chat.messages.SendStoryMessageRequest; import org.signal.chat.messages.SimpleMessagesAnonymousGrpc; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidVersionException; @@ -29,6 +31,7 @@ import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -99,8 +102,50 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me case AUTHORIZATION_NOT_SET -> throw Status.UNAUTHENTICATED.asException(); } + return sendIndividualMessage(destination, + destinationServiceIdentifier, + request.getMessages(), + request.getEphemeral(), + request.getUrgent(), + false); + } + + @Override + public SendMessageResponse sendStory(final SendStoryMessageRequest request) + throws StatusException, RateLimitExceededException { + + final ServiceIdentifier destinationServiceIdentifier = + ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getDestination()); + + final Optional maybeDestination = accountsManager.getByServiceIdentifier(destinationServiceIdentifier); + + if (maybeDestination.isEmpty()) { + // Don't reveal to unauthenticated callers whether a destination account actually exists + return SEND_MESSAGE_SUCCESS_RESPONSE; + } + + final Account destination = maybeDestination.get(); + + rateLimiters.getStoriesLimiter().validate(destination.getIdentifier(IdentityType.ACI)); + + return sendIndividualMessage(destination, + destinationServiceIdentifier, + request.getMessages(), + false, + request.getUrgent(), + true); + } + + private SendMessageResponse sendIndividualMessage(final Account destination, + final ServiceIdentifier destinationServiceIdentifier, + final IndividualRecipientMessageBundle messages, + final boolean ephemeral, + final boolean urgent, + final boolean story) throws StatusException, RateLimitExceededException { + final SpamCheckResult> spamCheckResult = - spamChecker.checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_SEALED_SENDER, + spamChecker.checkForIndividualRecipientSpamGrpc( + story ? MessageType.INDIVIDUAL_STORY : MessageType.INDIVIDUAL_SEALED_SENDER, Optional.empty(), Optional.of(destination), destinationServiceIdentifier); @@ -110,7 +155,7 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me } try { - final int totalPayloadLength = request.getMessages().getMessagesMap().values().stream() + final int totalPayloadLength = messages.getMessagesMap().values().stream() .mapToInt(message -> message.getPayload().size()) .sum(); @@ -120,28 +165,23 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me throw e; } - final Map messagesByDeviceId = request.getMessages().getMessagesMap().entrySet() + final Map messagesByDeviceId = messages.getMessagesMap().entrySet() .stream() .collect(Collectors.toMap( entry -> DeviceIdUtil.validate(entry.getKey()), - entry -> { - final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() - .setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER) - .setClientTimestamp(request.getMessages().getTimestamp()) - .setServerTimestamp(clock.millis()) - .setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString()) - .setEphemeral(request.getEphemeral()) - .setUrgent(request.getUrgent()) - .setContent(entry.getValue().getPayload()); - - spamCheckResult.token() - .ifPresent(token -> envelopeBuilder.setReportSpamToken(ByteString.copyFrom(token))); - - return envelopeBuilder.build(); - } + entry -> MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER) + .setClientTimestamp(messages.getTimestamp()) + .setServerTimestamp(clock.millis()) + .setDestinationServiceId(destinationServiceIdentifier.toServiceIdentifierString()) + .setEphemeral(ephemeral) + .setUrgent(urgent) + .setStory(story) + .setContent(entry.getValue().getPayload()) + .build() )); - final Map registrationIdsByDeviceId = request.getMessages().getMessagesMap().entrySet().stream() + final Map registrationIdsByDeviceId = messages.getMessagesMap().entrySet().stream() .collect(Collectors.toMap( entry -> entry.getKey().byteValue(), entry -> entry.getValue().getRegistrationId())); @@ -167,36 +207,47 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me public SendMultiRecipientMessageResponse sendMultiRecipientMessage(final SendMultiRecipientMessageRequest request) throws StatusException { - final SealedSenderMultiRecipientMessage multiRecipientMessage; - - try { - multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(request.getMessage().getPayload().toByteArray()); - } catch (final InvalidMessageException | InvalidVersionException e) { - throw Status.INVALID_ARGUMENT.withCause(e).asException(); - } - - // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same - // recipient - { - final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID]; - - for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { - Arrays.fill(usedDeviceIds, false); - - for (final byte deviceId : recipient.getDevices()) { - if (usedDeviceIds[deviceId]) { - throw Status.INVALID_ARGUMENT.withDescription("Request contains repeated device entries").asException(); - } - - usedDeviceIds[deviceId] = true; - } - } - } + final SealedSenderMultiRecipientMessage multiRecipientMessage = + parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray()); groupSendTokenUtil.checkGroupSendToken(request.getGroupSendToken(), multiRecipientMessage.getRecipients().keySet()); + return sendMultiRecipientMessage(multiRecipientMessage, + request.getMessage().getTimestamp(), + request.getEphemeral(), + request.getUrgent(), + false); + } + + @Override + public SendMultiRecipientMessageResponse sendMultiRecipientStory(final SendMultiRecipientStoryRequest request) + throws StatusException { + + final SealedSenderMultiRecipientMessage multiRecipientMessage = + parseAndValidateMultiRecipientMessage(request.getMessage().getPayload().toByteArray()); + + return sendMultiRecipientMessage(multiRecipientMessage, + request.getMessage().getTimestamp(), + false, + request.getUrgent(), + true) + .toBuilder() + // Don't identify unresolved recipients for stories + .clearUnresolvedRecipients() + .build(); + } + + private SendMultiRecipientMessageResponse sendMultiRecipientMessage( + final SealedSenderMultiRecipientMessage multiRecipientMessage, + final long timestamp, + final boolean ephemeral, + final boolean urgent, + final boolean story) throws StatusException { + final SpamCheckResult> spamCheckResult = - spamChecker.checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_SEALED_SENDER); + spamChecker.checkForMultiRecipientSpamGrpc(story + ? MessageType.MULTI_RECIPIENT_STORY + : MessageType.MULTI_RECIPIENT_SEALED_SENDER); if (spamCheckResult.response().isPresent()) { return spamCheckResult.response().get().getResponseOrThrowStatus(); @@ -205,26 +256,15 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me // At this point, the caller has at least superficially provided the information needed to send a multi-recipient // message. Attempt to resolve the destination service identifiers to Signal accounts. final Map resolvedRecipients = - Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) - .flatMap(serviceIdAndRecipient -> { - final ServiceIdentifier serviceIdentifier = - ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey()); - - return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) - .flatMap(Mono::justOrEmpty) - .map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account)); - }, MAX_FETCH_ACCOUNT_CONCURRENCY) - .collectMap(Tuple2::getT1, Tuple2::getT2) - .blockOptional() - .orElse(Collections.emptyMap()); + resolveRecipients(multiRecipientMessage); try { messageSender.sendMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, - request.getMessage().getTimestamp(), - false, - request.getEphemeral(), - request.getUrgent(), + timestamp, + story, + ephemeral, + urgent, RequestAttributesUtil.getRawUserAgent().orElse(null)); final SendMultiRecipientMessageResponse.Builder responseBuilder = SendMultiRecipientMessageResponse.newBuilder(); @@ -255,6 +295,62 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me } } + private SealedSenderMultiRecipientMessage parseAndValidateMultiRecipientMessage( + final byte[] serializedMultiRecipientMessage) throws StatusException { + + final SealedSenderMultiRecipientMessage multiRecipientMessage; + + try { + multiRecipientMessage = SealedSenderMultiRecipientMessage.parse(serializedMultiRecipientMessage); + } catch (final InvalidMessageException | InvalidVersionException e) { + throw Status.INVALID_ARGUMENT.withCause(e).asException(); + } + + // Check that the request is well-formed and doesn't contain repeated entries for the same device for the same + // recipient + validateNoDuplicateDevices(multiRecipientMessage); + + return multiRecipientMessage; + } + + private void validateNoDuplicateDevices(final SealedSenderMultiRecipientMessage multiRecipientMessage) + throws StatusException { + + final boolean[] usedDeviceIds = new boolean[Device.MAXIMUM_DEVICE_ID + 1]; + + for (final SealedSenderMultiRecipientMessage.Recipient recipient : multiRecipientMessage.getRecipients().values()) { + if (recipient.getDevices().length == 1) { + // A recipient can't have repeated devices if they only have one device + continue; + } + + Arrays.fill(usedDeviceIds, false); + + for (final byte deviceId : recipient.getDevices()) { + if (usedDeviceIds[deviceId]) { + throw Status.INVALID_ARGUMENT.withDescription("Request contains repeated device entries").asException(); + } + + usedDeviceIds[deviceId] = true; + } + } + } + + private Map resolveRecipients(final SealedSenderMultiRecipientMessage multiRecipientMessage) { + return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) + .flatMap(serviceIdAndRecipient -> { + final ServiceIdentifier serviceIdentifier = + ServiceIdentifier.fromLibsignal(serviceIdAndRecipient.getKey()); + + return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .flatMap(Mono::justOrEmpty) + .map(account -> Tuples.of(serviceIdAndRecipient.getValue(), account)); + }, MAX_FETCH_ACCOUNT_CONCURRENCY) + .collectMap(Tuple2::getT1, Tuple2::getT2) + .blockOptional() + .orElse(Collections.emptyMap()); + } + private MismatchedDevices buildMismatchedDevices(final ServiceIdentifier serviceIdentifier, org.whispersystems.textsecuregcm.controllers.MismatchedDevices mismatchedDevices) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java index 4f03d9610..3410c9d60 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java @@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -49,13 +48,16 @@ import org.signal.chat.messages.MultiRecipientMismatchedDevices; import org.signal.chat.messages.SendMessageResponse; import org.signal.chat.messages.SendMultiRecipientMessageRequest; import org.signal.chat.messages.SendMultiRecipientMessageResponse; +import org.signal.chat.messages.SendMultiRecipientStoryRequest; import org.signal.chat.messages.SendSealedSenderMessageRequest; +import org.signal.chat.messages.SendStoryMessageRequest; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -124,6 +126,7 @@ class MessagesAnonymousGrpcServiceTest extends .thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); + when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter); doThrow(Status.UNAUTHENTICATED.asException()).when(groupSendTokenUtil) .checkGroupSendToken(any(), any(ServiceIdentifier.class)); @@ -188,6 +191,7 @@ class MessagesAnonymousGrpcServiceTest extends .setServerTimestamp(CLOCK.millis()) .setEphemeral(ephemeral) .setUrgent(urgent) + .setStory(false) .setContent(ByteString.copyFrom(payload)) .build(); @@ -516,7 +520,7 @@ class MessagesAnonymousGrpcServiceTest extends eq(false), eq(ephemeral), eq(urgent), - isNull()); + any()); } @Test @@ -789,4 +793,515 @@ class MessagesAnonymousGrpcServiceTest extends .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); } } + + @Nested + class SingleRecipientStory { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void sendStory(final boolean urgent) throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final byte[] payload = TestRandomUtil.nextBytes(128); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(payload)) + .build()); + + final SendMessageResponse response = + unauthenticatedServiceStub().sendStory(generateRequest(serviceIdentifier, urgent, messages)); + + assertEquals(SendMessageResponse.newBuilder().build(), response); + + final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.UNIDENTIFIED_SENDER) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setEphemeral(false) + .setUrgent(urgent) + .setStory(true) + .setContent(ByteString.copyFrom(payload)) + .build(); + + verify(messageSender).sendMessages(destinationAccount, + serviceIdentifier, + Map.of(deviceId, expectedEnvelope), + Map.of(deviceId, registrationId), + null); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) + .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + final SendMessageResponse response = unauthenticatedServiceStub().sendStory( + generateRequest(serviceIdentifier, false, messages)); + + final SendMessageResponse expectedResponse = SendMessageResponse.newBuilder() + .setMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .addMissingDevices(missingDeviceId) + .addStaleDevices(staleDeviceId) + .addExtraDevices(extraDeviceId) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void destinationNotFound() throws MessageTooLargeException, MismatchedDevicesException { + final Map messages = + Map.of(Device.PRIMARY_ID, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(7) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + final SendMessageResponse response = unauthenticatedServiceStub().sendStory( + generateRequest(new AciServiceIdentifier(UUID.randomUUID()), true, messages)); + + assertEquals(SendMessageResponse.newBuilder().build(), response); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void rateLimited() throws RateLimitExceededException, MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + when(destinationAccount.getIdentifier(IdentityType.ACI)).thenReturn(serviceIdentifier.uuid()); + + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Duration retryDuration = Duration.ofHours(7); + doThrow(new RateLimitExceededException(retryDuration)).when(rateLimiter).validate(eq(serviceIdentifier.uuid())); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertRateLimitExceeded(retryDuration, + () -> unauthenticatedServiceStub().sendStory(generateRequest(serviceIdentifier, true, messages))); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = Map.of( + staleDeviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(Device.PRIMARY_ID) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusInvalidArgument( + () -> unauthenticatedServiceStub().sendStory(generateRequest(serviceIdentifier, false, messages))); + } + + @Test + void spamWithStatus() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED, + () -> unauthenticatedServiceStub().sendStory(generateRequest(serviceIdentifier, true, messages))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_STORY, + Optional.empty(), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + @Test + void spamWithResponse() throws MessageTooLargeException, MismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + when(accountsManager.getByServiceIdentifier(serviceIdentifier)).thenReturn(Optional.of(destinationAccount)); + + final Map messages = + Map.of(deviceId, IndividualRecipientMessageBundle.Message.newBuilder() + .setRegistrationId(registrationId) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()); + + final SendMessageResponse response = SendMessageResponse.newBuilder() + .setChallengeRequired(ChallengeRequired.newBuilder() + .addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA)) + .build(); + + when(spamChecker.checkForIndividualRecipientSpamGrpc(any(), any(), any(), any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty())); + + assertEquals(response, unauthenticatedServiceStub().sendStory( + generateRequest(serviceIdentifier, true, messages))); + + verify(spamChecker).checkForIndividualRecipientSpamGrpc(MessageType.INDIVIDUAL_STORY, + Optional.empty(), + Optional.of(destinationAccount), + serviceIdentifier); + + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + } + + private static SendStoryMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier, + final boolean urgent, + final Map messages) { + + final IndividualRecipientMessageBundle.Builder messageBundleBuilder = IndividualRecipientMessageBundle.newBuilder() + .setTimestamp(CLOCK.millis()); + + messages.forEach(messageBundleBuilder::putMessages); + + return SendStoryMessageRequest.newBuilder() + .setDestination(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .setMessages(messageBundleBuilder) + .setUrgent(urgent) + .build(); + } + } + + @Nested + class MultiRecipientStory { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void sendStory(final boolean urgent) throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account resolvedAccount = mock(Account.class); + when(resolvedAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(resolvedAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier resolvedServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final AciServiceIdentifier unresolvedServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(resolvedServiceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(resolvedAccount))); + + final TestRecipient resolvedRecipient = + new TestRecipient(resolvedServiceIdentifier, deviceId, registrationId, new byte[48]); + + final TestRecipient unresolvedRecipient = + new TestRecipient(unresolvedServiceIdentifier, Device.PRIMARY_ID, 1, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + resolvedRecipient, unresolvedRecipient)); + + final SendMultiRecipientStoryRequest request = SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(urgent) + .build(); + + assertEquals(SendMultiRecipientMessageResponse.newBuilder().build(), + unauthenticatedServiceStub().sendMultiRecipientStory(request)); + + verify(messageSender).sendMultiRecipientMessage(any(), + argThat(resolvedRecipients -> resolvedRecipients.containsValue(resolvedAccount)), + eq(CLOCK.millis()), + eq(true), + eq(false), + eq(urgent), + any()); + } + + @Test + void mismatchedDevices() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte missingDeviceId = Device.PRIMARY_ID; + final byte extraDeviceId = missingDeviceId + 1; + final byte staleDeviceId = extraDeviceId + 1; + + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(serviceIdentifier, staleDeviceId, 17, new byte[48]))); + + final SendMultiRecipientStoryRequest request = SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(true) + .build(); + + doThrow(new MultiRecipientMismatchedDevicesException(Map.of(serviceIdentifier, + new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( + Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId))))) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + final SendMultiRecipientMessageResponse response = + unauthenticatedServiceStub().sendMultiRecipientStory(request); + + final SendMultiRecipientMessageResponse expectedResponse = SendMultiRecipientMessageResponse.newBuilder() + .setMismatchedDevices(MultiRecipientMismatchedDevices.newBuilder() + .addMismatchedDevices(MismatchedDevices.newBuilder() + .setServiceIdentifier(ServiceIdentifierUtil.toGrpcServiceIdentifier(serviceIdentifier)) + .addMissingDevices(missingDeviceId) + .addExtraDevices(extraDeviceId) + .addStaleDevices(staleDeviceId) + .build()) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void badPayload() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientStory(SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) + .build()) + .build())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientMessage( + SendMultiRecipientMessageRequest.newBuilder().build())); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void repeatedRecipient() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final Device destinationDevice = DevicesHelper.createDevice(Device.PRIMARY_ID, CLOCK.millis(), 1); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = new TestRecipient(serviceIdentifier, Device.PRIMARY_ID, 1, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient, recipient)); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, () -> + unauthenticatedServiceStub().sendMultiRecipientStory(SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(true) + .build())); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void oversizedMessage() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final Account destinationAccount = mock(Account.class); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of( + new TestRecipient(serviceIdentifier, Device.PRIMARY_ID, 17, new byte[48]))); + + final SendMultiRecipientStoryRequest request = SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(true) + .build(); + + doThrow(new MessageTooLargeException()) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusInvalidArgument(() -> unauthenticatedServiceStub().sendMultiRecipientStory(request)); + } + + @Test + void spamWithStatus() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = + new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient)); + + final SendMultiRecipientStoryRequest request = SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(true) + .build(); + + when(spamChecker.checkForMultiRecipientSpamGrpc(any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withStatus(Status.RESOURCE_EXHAUSTED)), Optional.empty())); + + //noinspection ResultOfMethodCallIgnored + GrpcTestUtils.assertStatusException(Status.RESOURCE_EXHAUSTED, + () -> unauthenticatedServiceStub().sendMultiRecipientStory(request)); + + verify(spamChecker).checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_STORY); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + + @Test + void spamWithResponse() throws MessageTooLargeException, MultiRecipientMismatchedDevicesException { + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 7; + + final Device destinationDevice = DevicesHelper.createDevice(deviceId, CLOCK.millis(), registrationId); + + final Account destinationAccount = mock(Account.class); + when(destinationAccount.getDevices()).thenReturn(List.of(destinationDevice)); + when(destinationAccount.getDevice(deviceId)).thenReturn(Optional.of(destinationDevice)); + + final AciServiceIdentifier serviceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + + when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(destinationAccount))); + + final TestRecipient recipient = + new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]); + + final byte[] payload = MultiRecipientMessageHelper.generateMultiRecipientMessage(List.of(recipient)); + + final SendMultiRecipientStoryRequest request = SendMultiRecipientStoryRequest.newBuilder() + .setMessage(MultiRecipientMessage.newBuilder() + .setTimestamp(CLOCK.millis()) + .setPayload(ByteString.copyFrom(payload)) + .build()) + .setUrgent(true) + .build(); + + final SendMultiRecipientMessageResponse response = SendMultiRecipientMessageResponse.newBuilder() + .setChallengeRequired(ChallengeRequired.newBuilder() + .addChallengeOptions(ChallengeRequired.ChallengeType.CAPTCHA)) + .build(); + + when(spamChecker.checkForMultiRecipientSpamGrpc(any())) + .thenReturn(new SpamCheckResult<>(Optional.of(GrpcResponse.withResponse(response)), Optional.empty())); + + assertEquals(response, unauthenticatedServiceStub().sendMultiRecipientStory(request)); + + verify(spamChecker).checkForMultiRecipientSpamGrpc(MessageType.MULTI_RECIPIENT_STORY); + + verify(messageSender, never()) + .sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + } + } }