diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index bb85cf0c1..3e72f4d52 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -436,11 +436,16 @@ public class MessageController { final Map registrationIdsByDeviceId = messages.messages().stream() .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); + final Optional syncMessageSenderDeviceId = messageType == MessageType.SYNC + ? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId()) + : Optional.empty(); + try { messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId, + syncMessageSenderDeviceId, userAgent); } catch (final MismatchedDevicesException e) { if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java index ada01114e..b6b5265cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcService.java @@ -187,7 +187,8 @@ public class MessagesAnonymousGrpcService extends SimpleMessagesAnonymousGrpc.Me destination, destinationServiceIdentifier, messagesByDeviceId, - registrationIdsByDeviceId); + registrationIdsByDeviceId, + Optional.empty()); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java index f40b870e1..1e1f73fd0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcHelper.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.grpc; import io.grpc.Status; import io.grpc.StatusException; import java.util.Map; +import java.util.Optional; import org.signal.chat.messages.MismatchedDevices; import org.signal.chat.messages.SendMessageResponse; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; @@ -31,6 +32,8 @@ public class MessagesGrpcHelper { * @param destinationServiceIdentifier the service identifier for the destination account * @param messagesByDeviceId a map of device IDs to message payloads * @param registrationIdsByDeviceId a map of device IDs to device registration IDs + * @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the + * caller's own account), contains the ID of the device that sent the message * * @return a response object to send to callers * @@ -42,13 +45,16 @@ public class MessagesGrpcHelper { final Account destination, final ServiceIdentifier destinationServiceIdentifier, final Map messagesByDeviceId, - final Map registrationIdsByDeviceId) throws StatusException, RateLimitExceededException { + final Map registrationIdsByDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional syncMessageSenderDeviceId) + throws StatusException, RateLimitExceededException { try { messageSender.sendMessages(destination, destinationServiceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, + syncMessageSenderDeviceId, RequestAttributesUtil.getRawUserAgent().orElse(null)); return SEND_MESSAGE_SUCCESS_RESPONSE; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java index 342fa25a2..932329a8f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcService.java @@ -172,7 +172,8 @@ public class MessagesGrpcService extends SimpleMessagesGrpc.MessagesImplBase { destination, destinationServiceIdentifier, messagesByDeviceId, - registrationIdsByDeviceId); + registrationIdsByDeviceId, + messageType == MessageType.SYNC ? Optional.of(sender.deviceId()) : Optional.empty()); } private static MessageProtos.Envelope.Type getEnvelopeType(final AuthenticatedSenderMessageType type) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index da1f47ee2..f0aaa03cd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -13,7 +13,6 @@ import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -21,6 +20,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.util.Pair; @@ -36,7 +36,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Util; -import javax.annotation.Nullable; /** * A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages, @@ -86,6 +85,8 @@ public class MessageSender { * @param destinationIdentifier the service identifier to which the messages are addressed * @param messagesByDeviceId a map of device IDs to message payloads * @param registrationIdsByDeviceId a map of device IDs to device registration IDs + * @param syncMessageSenderDeviceId if the message is a sync message (i.e. a message to other devices linked to the + * caller's own account), contains the ID of the device that sent the message * @param userAgent the User-Agent string for the sender; may be {@code null} if not known * * @throws MismatchedDevicesException if the given bundle of messages did not include a message for all required @@ -97,38 +98,55 @@ public class MessageSender { final ServiceIdentifier destinationIdentifier, final Map messagesByDeviceId, final Map registrationIdsByDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional syncMessageSenderDeviceId, @Nullable final String userAgent) throws MismatchedDevicesException, MessageTooLargeException { - if (messagesByDeviceId.isEmpty()) { - // TODO Simply return and don't throw an exception when iOS clients no longer depend on this behavior - throw new MismatchedDevicesException(new MismatchedDevices( - destination.getDevices().stream().map(Device::getId).collect(Collectors.toSet()), - Collections.emptySet(), - Collections.emptySet())); - } - if (!destination.isIdentifiedBy(destinationIdentifier)) { throw new IllegalArgumentException("Destination account not identified by destination service identifier"); } - final Envelope firstMessage = messagesByDeviceId.values().iterator().next(); + final boolean isSyncMessage; + final boolean isStory; + final byte excludedDeviceId; - final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) && - destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId())); + if (syncMessageSenderDeviceId.isPresent()) { + if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isBlank(message.getSourceServiceId()) || + !destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { - final boolean isStory = firstMessage.getStory(); + throw new IllegalArgumentException("Sync message sender device ID specified, but one or more messages are not addressed to sender"); + } - validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent); + isSyncMessage = true; + isStory = false; + excludedDeviceId = syncMessageSenderDeviceId.get(); + } else { + if (messagesByDeviceId.values().stream().anyMatch(message -> StringUtils.isNotBlank(message.getSourceServiceId()) && + destination.isIdentifiedBy(ServiceIdentifier.valueOf(message.getSourceServiceId())))) { + + throw new IllegalArgumentException("Sync message sender device ID not specified, but one or more messages are addressed to sender"); + } + + isSyncMessage = false; + excludedDeviceId = NO_EXCLUDED_DEVICE_ID; + + // It's technically possible that the caller tried to send a story with an empty message list, in which case we'd + // incorrectly set this to `false`, but the mismatched device check will throw an exception before that matters. + isStory = messagesByDeviceId.values().stream().findAny() + .map(Envelope::getStory) + .orElse(false); + } final Optional maybeMismatchedDevices = getMismatchedDevices(destination, destinationIdentifier, registrationIdsByDeviceId, - isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID); + excludedDeviceId); if (maybeMismatchedDevices.isPresent()) { throw new MismatchedDevicesException(maybeMismatchedDevices.get()); } + validateIndividualMessageContentLength(messagesByDeviceId.values(), isSyncMessage, isStory, userAgent); + messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId) .forEach((deviceId, destinationPresent) -> { final Envelope message = messagesByDeviceId.get(deviceId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index 9f2c3320d..3c19dfaa9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.push; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -70,6 +71,7 @@ public class ReceiptSender { destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId, + Optional.empty(), UserAgentTagUtil.SERVER_UA); } catch (final Exception e) { logger.warn("Could not send delivery receipt", e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index a77227079..2a0700f63 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage; import com.google.protobuf.ByteString; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; @@ -116,7 +117,12 @@ public class ChangeNumberManager { final Map registrationIdsByDeviceId = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); - messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId, senderUserAgent); + messageSender.sendMessages(account, + serviceIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId, + Optional.of(Device.PRIMARY_ID), + senderUserAgent); } catch (final RuntimeException e) { logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); throw e; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 6a5514cfa..be9b4878a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -292,7 +292,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -319,7 +319,19 @@ class MessageControllerTest { IncomingMessageList.class), MediaType.APPLICATION_JSON_TYPE))) { - assertThat(response.getStatus(), is(equalTo(sendToPni ? 403 : 200))); + if (sendToPni) { + assertThat(response.getStatus(), is(equalTo(403))); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); + } else { + assertThat(response.getStatus(), is(equalTo(200))); + + verify(messageSender).sendMessages(any(), + eq(serviceIdentifier), + any(), + any(), + eq(Optional.of(Device.PRIMARY_ID)), + any()); + } } } @@ -337,7 +349,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -362,7 +374,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -400,7 +412,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -439,7 +451,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); if (expectedResponse == 200) { @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any(), eq(Optional.empty()), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -536,7 +548,7 @@ class MessageControllerTest { @Test void testMultiDeviceMissing() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet()))) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() @@ -558,7 +570,7 @@ class MessageControllerTest { @Test void testMultiDeviceExtra() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet()))) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() @@ -609,7 +621,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -633,7 +645,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any(), eq(Optional.empty()), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -658,6 +670,7 @@ class MessageControllerTest { any(), argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3), any(), + eq(Optional.empty()), any()); } } @@ -665,7 +678,7 @@ class MessageControllerTest { @Test void testRegistrationIdMismatch() throws Exception { doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2)))) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) @@ -1090,7 +1103,7 @@ class MessageControllerTest { @Test void testValidateContentLength() throws MismatchedDevicesException, MessageTooLargeException, IOException { - doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any()); + doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); try (final Response response = resources.getJerseyTest() @@ -1119,10 +1132,10 @@ class MessageControllerTest { if (expectOk) { assertEquals(200, response.getStatus()); - verify(messageSender).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); } else { assertEquals(422, response.getStatus()); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java index a302d9b7b..9549c9a06 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesAnonymousGrpcServiceTest.java @@ -215,6 +215,7 @@ class MessagesAnonymousGrpcServiceTest extends serviceIdentifier, Map.of(deviceId, expectedEnvelopeBuilder.build()), Map.of(deviceId, registrationId), + Optional.empty(), null); } @@ -238,7 +239,7 @@ class MessagesAnonymousGrpcServiceTest extends doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); final SendMessageResponse response = unauthenticatedServiceStub().sendSingleRecipientMessage( generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null)); @@ -290,7 +291,7 @@ class MessagesAnonymousGrpcServiceTest extends useUak ? incorrectUnidentifiedAccessKey : null, useUak ? null : incorrectGroupSendToken))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -308,7 +309,7 @@ class MessagesAnonymousGrpcServiceTest extends () -> unauthenticatedServiceStub().sendSingleRecipientMessage( generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -341,7 +342,7 @@ class MessagesAnonymousGrpcServiceTest extends () -> unauthenticatedServiceStub().sendSingleRecipientMessage( generateRequest(serviceIdentifier, false, true, messages, UNIDENTIFIED_ACCESS_KEY, null))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString()); } @@ -364,7 +365,7 @@ class MessagesAnonymousGrpcServiceTest extends .build()); doThrow(new MessageTooLargeException()) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusException(Status.INVALID_ARGUMENT, @@ -406,7 +407,7 @@ class MessagesAnonymousGrpcServiceTest extends Optional.of(destinationAccount), serviceIdentifier); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -446,7 +447,7 @@ class MessagesAnonymousGrpcServiceTest extends Optional.of(destinationAccount), serviceIdentifier); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } private static SendSealedSenderMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier, @@ -873,6 +874,7 @@ class MessagesAnonymousGrpcServiceTest extends serviceIdentifier, Map.of(deviceId, expectedEnvelopeBuilder.build()), Map.of(deviceId, registrationId), + Optional.empty(), null); } @@ -896,7 +898,7 @@ class MessagesAnonymousGrpcServiceTest extends doThrow(new MismatchedDevicesException(new org.whispersystems.textsecuregcm.controllers.MismatchedDevices( Set.of(missingDeviceId), Set.of(extraDeviceId), Set.of(staleDeviceId)))) - .when(messageSender).sendMessages(any(), any(), any(), any(), any()); + .when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); final SendMessageResponse response = unauthenticatedServiceStub().sendStory( generateRequest(serviceIdentifier, false, messages)); @@ -926,7 +928,7 @@ class MessagesAnonymousGrpcServiceTest extends assertEquals(SendMessageResponse.newBuilder().build(), response); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -957,7 +959,7 @@ class MessagesAnonymousGrpcServiceTest extends GrpcTestUtils.assertRateLimitExceeded(retryDuration, () -> unauthenticatedServiceStub().sendStory(generateRequest(serviceIdentifier, true, messages))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -978,7 +980,7 @@ class MessagesAnonymousGrpcServiceTest extends .setPayload(ByteString.copyFrom(TestRandomUtil.nextBytes(128))) .build()); - doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any()); + doThrow(new MessageTooLargeException()).when(messageSender).sendMessages(any(), any(), any(), any(), any(), any()); //noinspection ResultOfMethodCallIgnored GrpcTestUtils.assertStatusInvalidArgument( @@ -1017,7 +1019,7 @@ class MessagesAnonymousGrpcServiceTest extends Optional.of(destinationAccount), serviceIdentifier); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -1056,7 +1058,7 @@ class MessagesAnonymousGrpcServiceTest extends Optional.of(destinationAccount), serviceIdentifier); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } private static SendStoryMessageRequest generateRequest(final ServiceIdentifier serviceIdentifier, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java index ec21a19d3..86317bdc8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/MessagesGrpcServiceTest.java @@ -218,6 +218,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().sendMessage( generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); } @Test @@ -305,7 +306,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().sendMessage( generateRequest(serviceIdentifier, AuthenticatedSenderMessageType.DOUBLE_RATCHET, false, true, messages))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); verify(messageByteLimitEstimator).add(serviceIdentifier.uuid().toString()); } @@ -327,7 +328,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().sendSyncMessage( generateRequest(AuthenticatedSenderMessageType.DOUBLE_RATCHET, true, messages))); - verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any(), any(), any()); verify(messageByteLimitEstimator).add(AUTHENTICATED_ACI.toString()); } @@ -588,7 +590,7 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest messageSender.sendMessages(account, - new AciServiceIdentifier(UUID.randomUUID()), + serviceIdentifier, Collections.emptyMap(), Collections.emptyMap(), + Optional.empty(), + null)); + + assertDoesNotThrow(() -> messageSender.sendMessages(account, + serviceIdentifier, + Collections.emptyMap(), + Collections.emptyMap(), + Optional.of(Device.PRIMARY_ID), null)); } + + @Test + void sendSyncMessageMismatchedAddressing() { + final UUID accountIdentifier = UUID.randomUUID(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier); + final byte deviceId = Device.PRIMARY_ID; + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(accountIdentifier); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true); + + final Account nonSyncDestination = mock(Account.class); + when(nonSyncDestination.isIdentifiedBy(any())).thenReturn(true); + + assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(nonSyncDestination, + new AciServiceIdentifier(UUID.randomUUID()), + Map.of(deviceId, MessageProtos.Envelope.newBuilder().build()), + Map.of(deviceId, 17), + Optional.of(deviceId), + null), + "Should throw an IllegalArgumentException for inter-account messages with a sync message device ID"); + + assertThrows(IllegalArgumentException.class, () -> messageSender.sendMessages(account, + serviceIdentifier, + Map.of(deviceId, MessageProtos.Envelope.newBuilder() + .setSourceServiceId(serviceIdentifier.toServiceIdentifierString()) + .setSourceDevice(deviceId) + .build()), + Map.of(deviceId, 17), + Optional.empty(), + null), + "Should throw an IllegalArgumentException for self-addressed messages without a sync message device ID"); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 85004c317..9830a9e35 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -103,7 +103,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); - verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); } @Test @@ -117,7 +117,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); - verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any(), any(), any()); } @Test @@ -157,7 +157,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -210,7 +210,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -261,7 +261,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -308,7 +308,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -357,7 +357,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());