From c6689ca07ac10930a73dad4e9e2f000a3a9649d5 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 7 Apr 2025 09:15:39 -0400 Subject: [PATCH] Internalize destination device list/registration ID checks in `MessageSender` --- .../controllers/AccountControllerV2.java | 56 ++-- .../controllers/MessageController.java | 138 ++++----- .../controllers/MismatchedDevices.java | 11 + .../MismatchedDevicesException.java | 18 +- ...tiRecipientMismatchedDevicesException.java | 24 ++ .../controllers/StaleDevicesException.java | 22 -- .../entities/AccountMismatchedDevices.java | 2 +- .../entities/AccountStaleDevices.java | 2 +- .../entities/MismatchedDevices.java | 20 -- .../entities/MismatchedDevicesResponse.java | 20 ++ ...Devices.java => StaleDevicesResponse.java} | 8 +- .../textsecuregcm/push/MessageSender.java | 129 ++++++++- .../textsecuregcm/push/ReceiptSender.java | 16 +- .../storage/AccountsManager.java | 38 ++- .../storage/ChangeNumberManager.java | 50 ++-- .../util/DestinationDeviceValidator.java | 108 ------- .../controllers/MessageControllerTest.java | 147 ++++++---- .../textsecuregcm/push/MessageSenderTest.java | 209 +++++++++++++- .../storage/AccountsManagerTest.java | 46 +++ .../storage/ChangeNumberManagerTest.java | 93 +----- .../util/DestinationDeviceValidatorTest.java | 273 ------------------ 21 files changed, 675 insertions(+), 755 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevices.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevicesResponse.java rename service/src/main/java/org/whispersystems/textsecuregcm/entities/{StaleDevices.java => StaleDevicesResponse.java} (53%) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java index ac49fe873..d4c375484 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; -import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest; import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; -import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.MessageTooLargeException; @@ -93,8 +93,8 @@ public class AccountControllerV2 { @ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password") - @ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) - @ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class))) + @ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class))) + @ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class))) @ApiResponse(responseCode = "413", description = "One or more device messages was too large") @ApiResponse(responseCode = "422", description = "The request did not pass validation") @ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class))) @@ -150,16 +150,18 @@ public class AccountControllerV2 { return AccountIdentityResponseBuilder.fromAccount(updatedAccount); } catch (MismatchedDevicesException e) { - throw new WebApplicationException(Response.status(409) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity(new MismatchedDevices(e.getMissingDevices(), - e.getExtraDevices())) - .build()); - } catch (StaleDevicesException e) { - throw new WebApplicationException(Response.status(410) - .type(MediaType.APPLICATION_JSON) - .entity(new StaleDevices(e.getStaleDevices())) - .build()); + if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds())) + .build()); + } else { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(), + e.getMismatchedDevices().extraDeviceIds())) + .build()); + } } catch (IllegalArgumentException e) { throw new BadRequestException(e); } catch (MessageTooLargeException e) { @@ -178,9 +180,9 @@ public class AccountControllerV2 { @ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.") @ApiResponse(responseCode = "422", description = "The request body failed validation.") @ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.", - content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) + content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class))) @ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.", - content = @Content(schema = @Schema(implementation = StaleDevices.class))) + content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class))) @ApiResponse(responseCode = "413", description = "One or more device messages was too large") public AccountIdentityResponse distributePhoneNumberIdentityKeys( @Mutable @Auth final AuthenticatedDevice authenticatedDevice, @@ -207,16 +209,18 @@ public class AccountControllerV2 { return AccountIdentityResponseBuilder.fromAccount(updatedAccount); } catch (MismatchedDevicesException e) { - throw new WebApplicationException(Response.status(409) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity(new MismatchedDevices(e.getMissingDevices(), - e.getExtraDevices())) - .build()); - } catch (StaleDevicesException e) { - throw new WebApplicationException(Response.status(410) - .type(MediaType.APPLICATION_JSON) - .entity(new StaleDevices(e.getStaleDevices())) - .build()); + if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds())) + .build()); + } else { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(), + e.getMismatchedDevices().extraDeviceIds())) + .build()); + } } catch (IllegalArgumentException e) { throw new BadRequestException(e); } catch (MessageTooLargeException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index a952651fc..231a570cf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response.Status; import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.UUID; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -62,7 +60,6 @@ import javax.annotation.Nullable; import org.glassfish.jersey.server.ManagedAsync; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.ServiceId; -import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair; @@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type; -import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SpamReport; -import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; @@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; -import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; @@ -245,10 +241,10 @@ public class MessageController { description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user") @ApiResponse( responseCode = "409", description = "Incorrect set of devices supplied for recipient", - content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) + content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class))) @ApiResponse( responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", - content = @Content(schema = @Schema(implementation = StaleDevices.class))) + content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class))) @ApiResponse( responseCode="428", description="The sender should complete a challenge before proceeding") @@ -381,14 +377,6 @@ public class MessageController { rateLimiters.getStoriesLimiter().validate(destination.getUuid()); } - final Set excludedDeviceIds; - - if (isSyncMessage) { - excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId()); - } else { - excludedDeviceIds = Collections.emptySet(); - } - final Map messagesByDeviceId = messages.messages().stream() .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> { try { @@ -407,15 +395,8 @@ public class MessageController { } })); - DestinationDeviceValidator.validateCompleteDeviceList(destination, - messagesByDeviceId.keySet(), - excludedDeviceIds); - - DestinationDeviceValidator.validateRegistrationIds(destination, - messages.messages(), - IncomingMessage::destinationDeviceId, - IncomingMessage::destinationRegistrationId, - destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid())); + final Map registrationIdsByDeviceId = messages.messages().stream() + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); final String authType; if (SENDER_TYPE_IDENTIFIED.equals(senderType)) { @@ -428,7 +409,7 @@ public class MessageController { authType = AUTH_TYPE_ACCESS_KEY; } - messageSender.sendMessages(destination, messagesByDeviceId); + messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId); Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE), @@ -440,16 +421,18 @@ public class MessageController { return Response.ok(new SendMessageResponse(needsSync)).build(); } catch (final MismatchedDevicesException e) { - throw new WebApplicationException(Response.status(409) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity(new MismatchedDevices(e.getMissingDevices(), - e.getExtraDevices())) - .build()); - } catch (final StaleDevicesException e) { - throw new WebApplicationException(Response.status(410) - .type(MediaType.APPLICATION_JSON) - .entity(new StaleDevices(e.getStaleDevices())) - .build()); + if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds())) + .build()); + } else { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(), + e.getMismatchedDevices().extraDeviceIds())) + .build()); + } } } finally { sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME) @@ -622,57 +605,6 @@ public class MessageController { } } - final Collection accountMismatchedDevices = new ArrayList<>(); - final Collection accountStaleDevices = new ArrayList<>(); - - multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> { - if (!resolvedRecipients.containsKey(recipient)) { - // When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We - // can just skip them. - return; - } - - final Account account = resolvedRecipients.get(recipient); - - try { - final Map deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds() - .collect(Collectors.toMap(Pair::first, Pair::second)); - - DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(), - Collections.emptySet()); - - DestinationDeviceValidator.validateRegistrationIds( - account, - deviceIdsToRegistrationIds.entrySet(), - Map.Entry::getKey, - e -> Integer.valueOf(e.getValue()), - serviceId instanceof ServiceId.Pni); - } catch (final MismatchedDevicesException e) { - accountMismatchedDevices.add( - new AccountMismatchedDevices( - ServiceIdentifier.fromLibsignal(serviceId), - new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); - } catch (final StaleDevicesException e) { - accountStaleDevices.add( - new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices()))); - } - }); - - if (!accountMismatchedDevices.isEmpty()) { - return Response - .status(409) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity(accountMismatchedDevices) - .build(); - } - if (!accountStaleDevices.isEmpty()) { - return Response - .status(410) - .type(MediaType.APPLICATION_JSON) - .entity(accountStaleDevices) - .build(); - } - final String authType; if (isStory) { authType = AUTH_TYPE_STORY; @@ -731,6 +663,38 @@ public class MessageController { } catch (ExecutionException e) { logger.error("partial failure while delivering multi-recipient messages", e.getCause()); throw new InternalServerErrorException("failure during delivery"); + } catch (MultiRecipientMismatchedDevicesException e) { + final List accountMismatchedDevices = + e.getMismatchedDevicesByServiceIdentifier().entrySet().stream() + .filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty()) + .map(entry -> new AccountMismatchedDevices(entry.getKey(), + new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds()))) + .toList(); + + if (!accountMismatchedDevices.isEmpty()) { + return Response + .status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(accountMismatchedDevices) + .build(); + } + + final List accountStaleDevices = + e.getMismatchedDevicesByServiceIdentifier().entrySet().stream() + .filter(entry -> !entry.getValue().staleDeviceIds().isEmpty()) + .map(entry -> new AccountStaleDevices(entry.getKey(), + new StaleDevicesResponse(entry.getValue().staleDeviceIds()))) + .toList(); + + if (!accountStaleDevices.isEmpty()) { + return Response + .status(410) + .type(MediaType.APPLICATION_JSON) + .entity(accountStaleDevices) + .build(); + } + + throw new RuntimeException(e); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevices.java new file mode 100644 index 000000000..77a6a85a2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevices.java @@ -0,0 +1,11 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import java.util.Set; + +public record MismatchedDevices(Set missingDeviceIds, Set extraDeviceIds, Set staleDeviceIds) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java index ba79d3120..272ee9859 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java @@ -5,23 +5,15 @@ package org.whispersystems.textsecuregcm.controllers; -import java.util.List; - public class MismatchedDevicesException extends Exception { - private final List missingDevices; - private final List extraDevices; + private final MismatchedDevices mismatchedDevices; - public MismatchedDevicesException(List missingDevices, List extraDevices) { - this.missingDevices = missingDevices; - this.extraDevices = extraDevices; + public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) { + this.mismatchedDevices = mismatchedDevices; } - public List getMissingDevices() { - return missingDevices; - } - - public List getExtraDevices() { - return extraDevices; + public MismatchedDevices getMismatchedDevices() { + return mismatchedDevices; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java new file mode 100644 index 000000000..65a054725 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import java.util.Map; + +public class MultiRecipientMismatchedDevicesException extends Exception { + + private final Map mismatchedDevicesByServiceIdentifier; + + public MultiRecipientMismatchedDevicesException( + final Map mismatchedDevicesByServiceIdentifier) { + + this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier; + } + + public Map getMismatchedDevicesByServiceIdentifier() { + return mismatchedDevicesByServiceIdentifier; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java deleted file mode 100644 index 484c9f9ef..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.controllers; - -import java.util.List; - - -public class StaleDevicesException extends Exception { - - private final List staleDevices; - - public StaleDevicesException(List staleDevices) { - this.staleDevices = staleDevices; - } - - public List getStaleDevices() { - return staleDevices; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java index 55e939ebe..adb8c0de8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java @@ -14,5 +14,5 @@ public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierA @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) ServiceIdentifier uuid, - MismatchedDevices devices) { + MismatchedDevicesResponse devices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java index 031046d8c..1d88a460c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java @@ -14,5 +14,5 @@ public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapte @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) ServiceIdentifier uuid, - StaleDevices devices) { + StaleDevicesResponse devices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java deleted file mode 100644 index 169e803d2..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.entities; - -import com.fasterxml.jackson.annotation.JsonProperty; -import io.swagger.v3.oas.annotations.media.Schema; - -import java.util.List; - -public record MismatchedDevices(@JsonProperty - @Schema(description = "Devices present on the account but absent in the request") - List missingDevices, - - @JsonProperty - @Schema(description = "Devices absent on the request but present in the account") - List extraDevices) { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevicesResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevicesResponse.java new file mode 100644 index 000000000..ee7a9fd57 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevicesResponse.java @@ -0,0 +1,20 @@ +/* + * Copyright 2013-2020 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; + +import java.util.Set; + +public record MismatchedDevicesResponse(@JsonProperty + @Schema(description = "Devices present on the account but absent in the request") + Set missingDevices, + + @JsonProperty + @Schema(description = "Devices absent on the request but present in the account") + Set extraDevices) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevicesResponse.java similarity index 53% rename from service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java rename to service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevicesResponse.java index d49e077b9..1c20c30c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevicesResponse.java @@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import io.swagger.v3.oas.annotations.media.Schema; -import java.util.List; +import java.util.Set; -public record StaleDevices(@JsonProperty - @Schema(description = "Devices that are no longer active") - List staleDevices) { +public record StaleDevicesResponse(@JsonProperty + @Schema(description = "Devices that are no longer active") + Set staleDevices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index f7e798e3a..402caa412 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting; import io.dropwizard.util.DataSize; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; -import java.util.Map; -import java.util.concurrent.CompletableFuture; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.signal.libsignal.protocol.util.Pair; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; @@ -58,6 +69,9 @@ public class MessageSender { public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes(); + @VisibleForTesting + static final byte NO_EXCLUDED_DEVICE_ID = -1; + public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) { this.messagesManager = messagesManager; this.pushNotificationManager = pushNotificationManager; @@ -68,23 +82,51 @@ public class MessageSender { * notification token and does not have an active connection to a Signal server, then this method will also send a * push notification to that device to announce the availability of new messages. * - * @param account the account to which to send messages + * @param destination the account to which to send messages + * @param destinationIdentifier the service identifier to which the messages are addressed * @param messagesByDeviceId a map of device IDs to message payloads + * @param registrationIdsByDeviceId a map of device IDs to device registration IDs */ - public void sendMessages(final Account account, final Map messagesByDeviceId) { - messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId) + public void sendMessages(final Account destination, + final ServiceIdentifier destinationIdentifier, + final Map messagesByDeviceId, + final Map registrationIdsByDeviceId) throws MismatchedDevicesException { + + if (messagesByDeviceId.isEmpty()) { + return; + } + + if (!destination.isIdentifiedBy(destinationIdentifier)) { + throw new IllegalArgumentException("Destination account not identified by destination service identifier"); + } + + final Envelope firstMessage = messagesByDeviceId.values().iterator().next(); + + final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) && + destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId())); + + final Optional maybeMismatchedDevices = getMismatchedDevices(destination, + destinationIdentifier, + registrationIdsByDeviceId, + isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID); + + if (maybeMismatchedDevices.isPresent()) { + throw new MismatchedDevicesException(maybeMismatchedDevices.get()); + } + + messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId) .forEach((deviceId, destinationPresent) -> { final Envelope message = messagesByDeviceId.get(deviceId); if (!destinationPresent && !message.getEphemeral()) { try { - pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent()); + pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent()); } catch (final NotPushRegisteredException ignored) { } } Metrics.counter(SEND_COUNTER_NAME, - CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"), + CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"), EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()), CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent), URGENT_TAG_NAME, String.valueOf(message.getUrgent()), @@ -98,6 +140,10 @@ public class MessageSender { * Sends messages to a group of recipients. If a destination device has a valid push notification token and does not * have an active connection to a Signal server, then this method will also send a push notification to that device to * announce the availability of new messages. + *

+ * This method sends messages to all resolved recipients. In some cases, a caller may not be able to resolve + * all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting + * the message if they require full resolution of all recipients, but some recipients could not be resolved. * * @param multiRecipientMessage the multi-recipient message to send to the given recipients * @param resolvedRecipients a map of recipients to resolved Signal accounts @@ -114,7 +160,31 @@ public class MessageSender { final long clientTimestamp, final boolean isStory, final boolean isEphemeral, - final boolean isUrgent) { + final boolean isUrgent) throws MultiRecipientMismatchedDevicesException { + + final Map mismatchedDevicesByServiceIdentifier = new HashMap<>(); + + multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> { + if (!resolvedRecipients.containsKey(recipient)) { + // Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run + // into an unresolved recipient here, just skip it. + return; + } + + final Account account = resolvedRecipients.get(recipient); + final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId); + + final Map registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds() + .collect(Collectors.toMap(Pair::first, pair -> (int) pair.second())); + + getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID) + .ifPresent(mismatchedDevices -> + mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices)); + }); + + if (!mismatchedDevicesByServiceIdentifier.isEmpty()) { + throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier); + } return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, isStory, isEphemeral, isUrgent) @@ -189,4 +259,47 @@ public class MessageSender { .increment(); } } + + @VisibleForTesting + static Optional getMismatchedDevices(final Account account, + final ServiceIdentifier serviceIdentifier, + final Map registrationIdsByDeviceId, + final byte excludedDeviceId) { + + final Set accountDeviceIds = account.getDevices().stream() + .map(Device::getId) + .filter(deviceId -> deviceId != excludedDeviceId) + .collect(Collectors.toSet()); + + final Set missingDeviceIds = new HashSet<>(accountDeviceIds); + missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet()); + + final Set extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet()); + extraDeviceIds.removeAll(accountDeviceIds); + + final Set staleDeviceIds = registrationIdsByDeviceId.entrySet().stream() + // Filter out device IDs that aren't associated with the given account + .filter(entry -> !extraDeviceIds.contains(entry.getKey())) + .filter(entry -> { + final byte deviceId = entry.getKey(); + final int registrationId = entry.getValue(); + + // We know the device must be present because we've already filtered out device IDs that aren't associated + // with the given account + final Device device = account.getDevice(deviceId).orElseThrow(); + + final int expectedRegistrationId = switch (serviceIdentifier.identityType()) { + case ACI -> device.getRegistrationId(); + case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId); + }; + + return registrationId != expectedRegistrationId; + }) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + + return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty()) + ? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds)) + : Optional.empty(); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index 09da97e82..5d1b270a8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -54,9 +55,20 @@ public class ReceiptSender { .setUrgent(false) .build(); + final Map messagesByDeviceId = destinationAccount.getDevices().stream() + .collect(Collectors.toMap(Device::getId, ignored -> message)); + + final Map registrationIdsByDeviceId = destinationAccount.getDevices().stream() + .collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) { + case ACI -> device.getRegistrationId(); + case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId); + })); + try { - messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream() - .collect(Collectors.toMap(Device::getId, ignored -> message))); + messageSender.sendMessages(destinationAccount, + destinationIdentifier, + messagesByDeviceId, + registrationIdsByDeviceId); } catch (final Exception e) { logger.warn("Could not send delivery receipt", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index d7481448e..5c2207076 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -37,10 +37,12 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Queue; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; @@ -55,6 +57,7 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nullable; import javax.crypto.Mac; @@ -67,6 +70,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceInfo; @@ -83,7 +87,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException; -import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -780,24 +783,33 @@ public class AccountsManager extends RedisPubSubAdapter implemen } // Check that all including primary ID are in signed pre-keys - DestinationDeviceValidator.validateCompleteDeviceList( - account, - pniSignedPreKeys.keySet(), - Collections.emptySet()); + validateCompleteDeviceList(account, pniSignedPreKeys.keySet()); // Check that all including primary ID are in Pq pre-keys if (pniPqLastResortPreKeys != null) { - DestinationDeviceValidator.validateCompleteDeviceList( - account, - pniPqLastResortPreKeys.keySet(), - Collections.emptySet()); + validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet()); } // Check that all devices are accounted for in the map of new PNI registration IDs - DestinationDeviceValidator.validateCompleteDeviceList( - account, - pniRegistrationIds.keySet(), - Collections.emptySet()); + validateCompleteDeviceList(account, pniRegistrationIds.keySet()); + } + + @VisibleForTesting + static void validateCompleteDeviceList(final Account account, final Set deviceIds) throws MismatchedDevicesException { + final Set accountDeviceIds = account.getDevices().stream() + .map(Device::getId) + .collect(Collectors.toSet()); + + final Set missingDeviceIds = new HashSet<>(accountDeviceIds); + missingDeviceIds.removeAll(deviceIds); + + final Set extraDeviceIds = new HashSet<>(deviceIds); + extraDeviceIds.removeAll(accountDeviceIds); + + if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { + throw new MismatchedDevicesException( + new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet())); + } } public record UsernameReservation(Account account, byte[] reservedUsernameHash){} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 25c1f1534..d0d0cffce 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.storage; import com.google.protobuf.ByteString; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; @@ -15,15 +14,15 @@ import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageTooLargeException; -import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; public class ChangeNumberManager { @@ -45,12 +44,10 @@ public class ChangeNumberManager { @Nullable final List deviceMessages, @Nullable final Map pniRegistrationIds, @Nullable final String senderUserAgent) - throws InterruptedException, MismatchedDevicesException, StaleDevicesException, MessageTooLargeException { + throws InterruptedException, MismatchedDevicesException, MessageTooLargeException { - if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { - // AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds - validateDeviceMessages(account, deviceMessages); - } else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { + if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) || + ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds))) { throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null"); } @@ -84,9 +81,7 @@ public class ChangeNumberManager { @Nullable final Map devicePqLastResortPreKeys, final List deviceMessages, final Map pniRegistrationIds, - final String senderUserAgent) throws MismatchedDevicesException, StaleDevicesException, MessageTooLargeException { - - validateDeviceMessages(account, deviceMessages); + final String senderUserAgent) throws MismatchedDevicesException, MessageTooLargeException { // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb // write anyway. Linked devices can handle some wasted extra key rotations. @@ -97,26 +92,9 @@ public class ChangeNumberManager { return updatedAccount; } - private void validateDeviceMessages(final Account account, - final List deviceMessages) throws MismatchedDevicesException, StaleDevicesException { - // Check that all except primary ID are in device messages - DestinationDeviceValidator.validateCompleteDeviceList( - account, - deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()), - Set.of(Device.PRIMARY_ID)); - - // check that all sync messages are to the current registration ID for the matching device - DestinationDeviceValidator.validateRegistrationIds( - account, - deviceMessages, - IncomingMessage::destinationDeviceId, - IncomingMessage::destinationRegistrationId, - false); - } - private void sendDeviceMessages(final Account account, final List deviceMessages, - final String senderUserAgent) throws MessageTooLargeException { + final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { for (final IncomingMessage message : deviceMessages) { MessageSender.validateContentLength(message.content().length, @@ -128,20 +106,26 @@ public class ChangeNumberManager { try { final long serverTimestamp = System.currentTimeMillis(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); - messageSender.sendMessages(account, deviceMessages.stream() + final Map messagesByDeviceId = deviceMessages.stream() .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder() .setType(Envelope.Type.forNumber(message.type())) .setClientTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp) - .setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) + .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString()) .setContent(ByteString.copyFrom(message.content())) - .setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) + .setSourceServiceId(serviceIdentifier.toServiceIdentifierString()) .setSourceDevice(Device.PRIMARY_ID) .setUpdatedPni(account.getPhoneNumberIdentifier().toString()) .setUrgent(true) .setEphemeral(false) - .build()))); + .build())); + + final Map registrationIdsByDeviceId = account.getDevices().stream() + .collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); + + messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId); } catch (final RuntimeException e) { logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); throw e; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java deleted file mode 100644 index 4ba72f4b9..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2013-2022 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.util; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; - -public class DestinationDeviceValidator { - - /** - * @see #validateRegistrationIds(Account, Stream, boolean) - */ - public static void validateRegistrationIds(final Account account, - final Collection messages, - Function getDeviceId, - Function getRegistrationId, - boolean usePhoneNumberIdentity) throws StaleDevicesException { - - validateRegistrationIds(account, - messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))), - usePhoneNumberIdentity); - } - - /** - * Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID - * pairs in the given destination account. This method does not validate that all devices associated with the - * destination account are present in the given device ID/registration ID pairs. - * - * @param account the destination account against which to check the given device - * ID/registration ID pairs - * @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs - * @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device - * registration IDs associated with the account's PNI (if available); compare - * against the ACI-associated registration ID otherwise - * @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination - * account does not have a corresponding device or if the registration IDs do not match - */ - public static void validateRegistrationIds(final Account account, - final Stream> deviceIdAndRegistrationIdStream, - final boolean usePhoneNumberIdentity) throws StaleDevicesException { - - final List staleDevices = deviceIdAndRegistrationIdStream - .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) - .filter(deviceIdAndRegistrationId -> { - final byte deviceId = deviceIdAndRegistrationId.first(); - final int registrationId = deviceIdAndRegistrationId.second(); - boolean registrationIdMatches = account.getDevice(deviceId) - .map(device -> registrationId == (usePhoneNumberIdentity - ? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) - : device.getRegistrationId())) - .orElse(false); - return !registrationIdMatches; - }) - .map(Pair::first) - .collect(Collectors.toList()); - - if (!staleDevices.isEmpty()) { - throw new StaleDevicesException(staleDevices); - } - } - - /** - * Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the - * given destination account in preparation for sending those messages to the destination account. In general, the set - * of device IDs must exactly match the set of active devices associated with the destination account. When sending a - * "sync," message, though, the authenticated account is sending messages from one of their devices to all other - * devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}. - * - * @param account the destination account against which to check the given set of device IDs - * @param messageDeviceIds the set of device IDs to check against the destination account - * @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be - * present in the given set of device IDs (i.e. the device that is sending a sync message) - * @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with - * the destination account or is missing entries associated with the destination - * account - */ - public static void validateCompleteDeviceList(final Account account, - final Set messageDeviceIds, - final Set excludedDeviceIds) throws MismatchedDevicesException { - - final Set accountDeviceIds = account.getDevices().stream() - .map(Device::getId) - .filter(deviceId -> !excludedDeviceIds.contains(deviceId)) - .collect(Collectors.toSet()); - - final Set missingDeviceIds = new HashSet<>(accountDeviceIds); - missingDeviceIds.removeAll(messageDeviceIds); - - final Set extraDeviceIds = new HashSet<>(messageDeviceIds); - extraDeviceIds.removeAll(accountDeviceIds); - - if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { - throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds)); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index e258f322b..ce2524cb2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SpamReport; -import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; @@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import javax.annotation.Nullable; @ExtendWith(DropwizardExtensionsSupport.class) class MessageControllerTest { @@ -195,7 +197,7 @@ class MessageControllerTest { .build(); @BeforeEach - void setup() { + void setup() throws MultiRecipientMismatchedDevicesException { reset(pushNotificationScheduler); when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) @@ -287,7 +289,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), captor.capture()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -332,7 +334,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), captor.capture()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -357,7 +359,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), captor.capture()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -395,7 +397,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(200))); @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), captor.capture()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -434,7 +436,7 @@ class MessageControllerTest { assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); if (expectedResponse == 200) { @SuppressWarnings("unchecked") final ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), captor.capture()); + verify(messageSender).sendMessages(any(), any(), captor.capture(), any()); assertEquals(1, captor.getValue().size()); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); @@ -530,6 +532,9 @@ class MessageControllerTest { @Test void testMultiDeviceMissing() throws Exception { + doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet()))) + .when(messageSender).sendMessages(any(), any(), any(), any()); + try (final Response response = resources.getJerseyTest() .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) @@ -542,15 +547,16 @@ class MessageControllerTest { assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Body", - asJson(response.readEntity(MismatchedDevices.class)), + asJson(response.readEntity(MismatchedDevicesResponse.class)), is(equalTo(jsonFixture("fixtures/missing_device_response.json")))); - - verifyNoMoreInteractions(messageSender); } } @Test void testMultiDeviceExtra() throws Exception { + doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet()))) + .when(messageSender).sendMessages(any(), any(), any(), any()); + try (final Response response = resources.getJerseyTest() .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) @@ -563,10 +569,8 @@ class MessageControllerTest { assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Body", - asJson(response.readEntity(MismatchedDevices.class)), + asJson(response.readEntity(MismatchedDevicesResponse.class)), is(equalTo(jsonFixture("fixtures/missing_device_response2.json")))); - - verifyNoMoreInteractions(messageSender); } } @@ -602,7 +606,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -626,7 +630,7 @@ class MessageControllerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any()); assertEquals(3, envelopeCaptor.getValue().size()); @@ -648,12 +652,17 @@ class MessageControllerTest { assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); verify(messageSender).sendMessages(any(Account.class), - argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3)); + any(), + argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3), + any()); } } @Test void testRegistrationIdMismatch() throws Exception { + doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2)))) + .when(messageSender).sendMessages(any(), any(), any(), any()); + try (final Response response = resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .request() @@ -665,10 +674,8 @@ class MessageControllerTest { assertThat("Good Response Code", response.getStatus(), is(equalTo(410))); assertThat("Good Response Body", - asJson(response.readEntity(StaleDevices.class)), + asJson(response.readEntity(StaleDevicesResponse.class)), is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json")))); - - verifyNoMoreInteractions(messageSender); } } @@ -1078,7 +1085,7 @@ class MessageControllerTest { } @Test - void testValidateContentLength() { + void testValidateContentLength() throws MismatchedDevicesException { final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1); final byte[] contentBytes = new byte[contentLength]; Arrays.fill(contentBytes, (byte) 1); @@ -1095,7 +1102,7 @@ class MessageControllerTest { assertThat("Bad response", response.getStatus(), is(equalTo(413))); - verify(messageSender, never()).sendMessages(any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any()); } } @@ -1113,10 +1120,10 @@ class MessageControllerTest { if (expectOk) { assertEquals(200, response.getStatus()); - verify(messageSender).sendMessages(any(), any()); + verify(messageSender).sendMessages(any(), any(), any(), any()); } else { assertEquals(422, response.getStatus()); - verify(messageSender, never()).sendMessages(any(), any()); + verify(messageSender, never()).sendMessages(any(), any(), any(), any()); } } } @@ -1140,7 +1147,9 @@ class MessageControllerTest { final Optional maybeGroupSendToken, final int expectedStatus, final Set expectedResolvedAccounts, - final Set expectedUuids404) { + final Set expectedUuids404, + @Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException) + throws MultiRecipientMismatchedDevicesException { clock.pin(START_OF_DAY); @@ -1151,6 +1160,11 @@ class MessageControllerTest { when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))))); + if (mismatchedDevicesException != null) { + doThrow(mismatchedDevicesException) + .when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()); + } + final boolean ephemeral = true; final boolean urgent = false; @@ -1187,7 +1201,7 @@ class MessageControllerTest { assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404)); } - if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) { + if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) { verify(messageSender).sendMultiRecipientMessage(any(), argThat(resolvedRecipients -> new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)), @@ -1267,7 +1281,8 @@ class MessageControllerTest { Optional.empty(), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Multi-recipient message with combined UAKs", accountsByServiceIdentifier, @@ -1279,7 +1294,8 @@ class MessageControllerTest { Optional.empty(), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Multi-recipient message with group send endorsement", accountsByServiceIdentifier, @@ -1291,7 +1307,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Incorrect combined UAK", accountsByServiceIdentifier, @@ -1303,7 +1320,8 @@ class MessageControllerTest { Optional.empty(), 401, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Incorrect group send endorsement", accountsByServiceIdentifier, @@ -1317,7 +1335,8 @@ class MessageControllerTest { START_OF_DAY.plus(Duration.ofDays(1)))), 401, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), // Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if // provided @@ -1331,7 +1350,8 @@ class MessageControllerTest { Optional.empty(), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Story with group send endorsement", accountsByServiceIdentifier, @@ -1343,7 +1363,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Conflicting credentials", accountsByServiceIdentifier, @@ -1355,7 +1376,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("No credentials", accountsByServiceIdentifier, @@ -1367,7 +1389,8 @@ class MessageControllerTest { Optional.empty(), 401, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Oversized payload", accountsByServiceIdentifier, @@ -1383,7 +1406,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 413, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Negative timestamp", accountsByServiceIdentifier, @@ -1395,7 +1419,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Excessive timestamp", accountsByServiceIdentifier, @@ -1407,7 +1432,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Empty recipient list", accountsByServiceIdentifier, @@ -1421,7 +1447,8 @@ class MessageControllerTest { START_OF_DAY.plus(Duration.ofDays(1)))), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Story with empty recipient list", accountsByServiceIdentifier, @@ -1433,7 +1460,8 @@ class MessageControllerTest { Optional.empty(), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Duplicate recipient", accountsByServiceIdentifier, @@ -1447,7 +1475,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 400, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Missing account", Map.of(), @@ -1459,7 +1488,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 200, Collections.emptySet(), - Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))), + Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)), + null), Arguments.argumentSet("One missing and one existing account", Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount), @@ -1471,7 +1501,8 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 200, Set.of(singleDeviceAccount), - Set.of(new AciServiceIdentifier(multiDeviceAccountAci))), + Set.of(new AciServiceIdentifier(multiDeviceAccountAci)), + null), Arguments.argumentSet("Missing account for story", Map.of(), @@ -1483,7 +1514,8 @@ class MessageControllerTest { Optional.empty(), 200, Collections.emptySet(), - Set.of()), + Set.of(), + null), Arguments.argumentSet("One missing and one existing account for story", Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount), @@ -1495,7 +1527,8 @@ class MessageControllerTest { Optional.empty(), 200, Set.of(singleDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Missing device", accountsByServiceIdentifier, @@ -1509,7 +1542,9 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 409, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci), + new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))), Arguments.argumentSet("Extra device", accountsByServiceIdentifier, @@ -1525,7 +1560,9 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 409, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci), + new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))), Arguments.argumentSet("Stale registration ID", accountsByServiceIdentifier, @@ -1540,7 +1577,9 @@ class MessageControllerTest { Optional.of(groupSendEndorsement), 410, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci), + new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))), Arguments.argumentSet("Rate-limited story", accountsByServiceIdentifier, @@ -1552,7 +1591,8 @@ class MessageControllerTest { Optional.empty(), 429, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Story to PNI recipients", accountsByServiceIdentifier, @@ -1567,7 +1607,8 @@ class MessageControllerTest { Optional.empty(), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK", accountsByServiceIdentifier, @@ -1582,7 +1623,8 @@ class MessageControllerTest { Optional.empty(), 401, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()), + Set.of(), + null), Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement", accountsByServiceIdentifier, @@ -1599,7 +1641,8 @@ class MessageControllerTest { START_OF_DAY.plus(Duration.ofDays(1)))), 200, Set.of(singleDeviceAccount, multiDeviceAccount), - Set.of()) + Set.of(), + null) ); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 5282e98eb..55b698723 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -22,6 +22,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; @@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.signal.libsignal.protocol.InvalidMessageException; +import org.signal.libsignal.protocol.InvalidVersionException; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; +import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; +import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper; +import org.whispersystems.textsecuregcm.tests.util.TestRecipient; class MessageSenderTest { @@ -60,7 +73,9 @@ class MessageSenderTest { final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; final UUID accountIdentifier = UUID.randomUUID(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier); final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 17; final Account account = mock(Account.class); final Device device = mock(Device.class); @@ -71,7 +86,11 @@ class MessageSenderTest { when(account.getUuid()).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true); + when(account.getDevices()).thenReturn(List.of(device)); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(device.getId()).thenReturn(deviceId); + when(device.getRegistrationId()).thenReturn(registrationId); if (hasPushToken) { when(device.getApnId()).thenReturn("apns-token"); @@ -82,7 +101,10 @@ class MessageSenderTest { when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent)); - assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message))); + assertDoesNotThrow(() -> messageSender.sendMessages(account, + serviceIdentifier, + Map.of(device.getId(), message), + Map.of(device.getId(), registrationId))); final MessageProtos.Envelope expectedMessage = ephemeral ? message.toBuilder().setEphemeral(true).build() @@ -97,23 +119,61 @@ class MessageSenderTest { } } + @Test + void sendMessageMismatchedDevices() { + final UUID accountIdentifier = UUID.randomUUID(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier); + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 17; + + final Account account = mock(Account.class); + final Device device = mock(Device.class); + final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build(); + + when(account.getUuid()).thenReturn(accountIdentifier); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true); + when(account.getDevices()).thenReturn(List.of(device)); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + when(device.getId()).thenReturn(deviceId); + when(device.getRegistrationId()).thenReturn(registrationId); + when(device.getApnId()).thenReturn("apns-token"); + + final MismatchedDevicesException mismatchedDevicesException = + assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account, + serviceIdentifier, + Map.of(device.getId(), message), + Map.of(device.getId(), registrationId + 1))); + + assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)), + mismatchedDevicesException.getMismatchedDevices()); + } + @CartesianTest void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent, @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, @CartesianTest.Values(booleans = {true, false}) final boolean urgent, - @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException { + @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) + throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException { final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; final UUID accountIdentifier = UUID.randomUUID(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier); final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 17; final Account account = mock(Account.class); final Device device = mock(Device.class); when(account.getUuid()).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true); + when(account.getDevices()).thenReturn(List.of(device)); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(device.getId()).thenReturn(deviceId); + when(device.getRegistrationId()).thenReturn(registrationId); + when(device.getApnId()).thenReturn("apns-token"); if (hasPushToken) { when(device.getApnId()).thenReturn("apns-token"); @@ -125,12 +185,19 @@ class MessageSenderTest { when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) .thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent)))); - assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class), - Collections.emptyMap(), - System.currentTimeMillis(), - false, - ephemeral, - urgent) + final SealedSenderMultiRecipientMessage multiRecipientMessage = + SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage( + List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48])))); + + final SealedSenderMultiRecipientMessage.Recipient recipient = + multiRecipientMessage.getRecipients().values().iterator().next(); + + assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage, + Map.of(recipient, account), + System.currentTimeMillis(), + false, + ephemeral, + urgent) .join()); if (expectPushNotificationAttempt) { @@ -140,6 +207,49 @@ class MessageSenderTest { } } + @Test + void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException { + final UUID accountIdentifier = UUID.randomUUID(); + final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier); + final byte deviceId = Device.PRIMARY_ID; + final int registrationId = 17; + + final Account account = mock(Account.class); + final Device device = mock(Device.class); + + when(account.getUuid()).thenReturn(accountIdentifier); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true); + when(account.getDevices()).thenReturn(List.of(device)); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + when(device.getId()).thenReturn(deviceId); + when(device.getRegistrationId()).thenReturn(registrationId); + when(device.getApnId()).thenReturn("apns-token"); + + final SealedSenderMultiRecipientMessage multiRecipientMessage = + SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage( + List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48])))); + + final SealedSenderMultiRecipientMessage.Recipient recipient = + multiRecipientMessage.getRecipients().values().iterator().next(); + + when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) + .thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true)))); + + final MultiRecipientMismatchedDevicesException mismatchedDevicesException = + assertThrows(MultiRecipientMismatchedDevicesException.class, + () -> messageSender.sendMultiRecipientMessage(multiRecipientMessage, + Map.of(recipient, account), + System.currentTimeMillis(), + false, + false, + true) + .join()); + + assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))), + mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier()); + } + @ParameterizedTest @MethodSource void getDeliveryChannelName(final Device device, final String expectedChannelName) { @@ -183,4 +293,87 @@ class MessageSenderTest { assertDoesNotThrow(() -> MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null)); } + + @ParameterizedTest + @MethodSource + void getMismatchedDevices(final Account account, + final ServiceIdentifier serviceIdentifier, + final Map registrationIdsByDeviceId, + final byte excludedDeviceId, + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional expectedMismatchedDevices) { + + assertEquals(expectedMismatchedDevices, + MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId)); + } + + private static List getMismatchedDevices() { + final byte primaryDeviceId = Device.PRIMARY_ID; + final byte linkedDeviceId = primaryDeviceId + 1; + final byte extraDeviceId = linkedDeviceId + 1; + + final int primaryDeviceAciRegistrationId = 2; + final int primaryDevicePniRegistrationId = 3; + final int linkedDeviceAciRegistrationId = 5; + final int linkedDevicePniRegistrationId = 7; + + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(primaryDeviceId); + when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId); + when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId)); + + final Device linkedDevice = mock(Device.class); + when(linkedDevice.getId()).thenReturn(linkedDeviceId); + when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId); + when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId)); + + final Account account = mock(Account.class); + when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); + when(account.getDevice(anyByte())).thenReturn(Optional.empty()); + when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice)); + when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice)); + + final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); + final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID()); + + return List.of( + Arguments.argumentSet("Complete device list for ACI, no devices excluded", + account, + aciServiceIdentifier, + Map.of( + primaryDeviceId, primaryDeviceAciRegistrationId, + linkedDeviceId, linkedDeviceAciRegistrationId + ), + MessageSender.NO_EXCLUDED_DEVICE_ID, + Optional.empty()), + + Arguments.argumentSet("Complete device list for PNI, no devices excluded", + account, + pniServiceIdentifier, + Map.of( + primaryDeviceId, primaryDevicePniRegistrationId, + linkedDeviceId, linkedDevicePniRegistrationId + ), + MessageSender.NO_EXCLUDED_DEVICE_ID, + Optional.empty()), + + Arguments.argumentSet("Complete device list, device excluded", + account, + aciServiceIdentifier, + Map.of( + linkedDeviceId, linkedDeviceAciRegistrationId + ), + primaryDeviceId, + Optional.empty()), + + Arguments.argumentSet("Mismatched devices", + account, + aciServiceIdentifier, + Map.of( + linkedDeviceId, linkedDeviceAciRegistrationId + 1, + extraDeviceId, 17 + ), + MessageSender.NO_EXCLUDED_DEVICE_ID, + Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId)))) + ); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 704118d06..88de8f8f5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -60,10 +60,12 @@ import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; +import javax.annotation.Nullable; import javax.crypto.spec.SecretKeySpec; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; @@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevices; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; @@ -1705,4 +1708,47 @@ class AccountsManagerTest { Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp) ); } + + @ParameterizedTest + @MethodSource + void validateCompleteDeviceList(final Account account, final Set deviceIds, @Nullable final MismatchedDevicesException expectedException) { + final Executable validateCompleteDeviceListExecutable = + () -> AccountsManager.validateCompleteDeviceList(account, deviceIds); + + if (expectedException != null) { + final MismatchedDevicesException caughtException = + assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable); + + assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices()); + } else { + assertDoesNotThrow(validateCompleteDeviceListExecutable); + } + } + + private static List validateCompleteDeviceList() { + final byte deviceId = Device.PRIMARY_ID; + final byte extraDeviceId = deviceId + 1; + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + + final Account account = mock(Account.class); + when(account.getDevices()).thenReturn(List.of(device)); + + return List.of( + Arguments.of(account, Set.of(deviceId), null), + + Arguments.of(account, Set.of(deviceId, extraDeviceId), + new MismatchedDevicesException( + new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))), + + Arguments.of(account, Collections.emptySet(), + new MismatchedDevicesException( + new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))), + + Arguments.of(account, Set.of(extraDeviceId), + new MismatchedDevicesException( + new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet()))) + ); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 552cb640b..2d82edb0e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -105,7 +106,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); - verify(messageSender, never()).sendMessages(eq(account), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any()); } @Test @@ -119,7 +120,7 @@ public class ChangeNumberManagerTest { changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); - verify(messageSender, never()).sendMessages(eq(account), any()); + verify(messageSender, never()).sendMessages(eq(account), any(), any(), any()); } @Test @@ -159,7 +160,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -212,7 +213,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -263,7 +264,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -310,7 +311,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -359,7 +360,7 @@ public class ChangeNumberManagerTest { @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = ArgumentCaptor.forClass(Map.class); - verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); + verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any()); assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); @@ -372,82 +373,6 @@ public class ChangeNumberManagerTest { assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); } - @Test - void changeNumberMismatchedRegistrationId() { - final Account account = mock(Account.class); - when(account.getNumber()).thenReturn("+18005551234"); - - final List devices = new ArrayList<>(); - - for (byte i = 1; i <= 3; i++) { - final Device device = mock(Device.class); - when(device.getId()).thenReturn(i); - when(device.getRegistrationId()).thenReturn((int) i); - - devices.add(device); - when(account.getDevice(i)).thenReturn(Optional.of(device)); - } - - when(account.getDevices()).thenReturn(devices); - - final byte destinationDeviceId2 = 2; - final byte destinationDeviceId3 = 3; - final List messages = List.of( - new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)), - new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8))); - - final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); - - final Map preKeys = Map.of(Device.PRIMARY_ID, - KeysHelper.signedECPreKey(1, pniIdentityKeyPair), - destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), - destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); - final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47, - destinationDeviceId3, 89); - - assertThrows(StaleDevicesException.class, - () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null)); - } - - @Test - void updatePniKeysMismatchedRegistrationId() { - final Account account = mock(Account.class); - when(account.getNumber()).thenReturn("+18005551234"); - - final List devices = new ArrayList<>(); - - for (byte i = 1; i <= 3; i++) { - final Device device = mock(Device.class); - when(device.getId()).thenReturn(i); - when(device.getRegistrationId()).thenReturn((int) i); - - devices.add(device); - when(account.getDevice(i)).thenReturn(Optional.of(device)); - } - - when(account.getDevices()).thenReturn(devices); - - final byte destinationDeviceId2 = 2; - final byte destinationDeviceId3 = 3; - final List messages = List.of( - new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)), - new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8))); - - final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); - - final Map preKeys = Map.of(Device.PRIMARY_ID, - KeysHelper.signedECPreKey(1, pniIdentityKeyPair), - destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), - destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); - final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47, - destinationDeviceId3, 89); - - assertThrows(StaleDevicesException.class, - () -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null)); - } - @Test void changeNumberMissingData() { final Account account = mock(Account.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java deleted file mode 100644 index 9fc32b5ca..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Copyright 2013-2022 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.util; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.params.provider.Arguments.arguments; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.OptionalInt; -import java.util.Set; -import java.util.stream.Stream; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; -import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; - -@ExtendWith(DropwizardExtensionsSupport.class) -class DestinationDeviceValidatorTest { - - static Account mockAccountWithDeviceAndRegId(final Map registrationIdsByDeviceId) { - final Account account = mock(Account.class); - - registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { - final Device device = mock(Device.class); - when(device.getRegistrationId()).thenReturn(registrationId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - }); - - return account; - } - - static Stream validateRegistrationIdsSource() { - final byte id1 = 1; - final byte id2 = 2; - final byte id3 = 3; - return Stream.of( - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)), - Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF), - null), - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 42)), - Map.of(id1, 1492), - Set.of(id1)), - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 42)), - Map.of(id1, 42), - null), - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 42)), - Map.of(id1, 0), - null), - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)), - Map.of(id1, 0, id2, 42), - Set.of(id2)), - arguments( - mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)), - Map.of(id1, 41, id2, 257), - Set.of(id1, id2)) - ); - } - - @ParameterizedTest - @MethodSource("validateRegistrationIdsSource") - void testValidateRegistrationIds( - Account account, - Map registrationIdsByDeviceId, - Set expectedStaleDeviceIds) throws Exception { - if (expectedStaleDeviceIds != null) { - Assertions.assertThat(assertThrows(StaleDevicesException.class, - () -> DestinationDeviceValidator.validateRegistrationIds( - account, - registrationIdsByDeviceId.entrySet(), - Map.Entry::getKey, - Map.Entry::getValue, - false)) - .getStaleDevices()) - .hasSameElementsAs(expectedStaleDeviceIds); - } else { - DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(), - Map.Entry::getKey, Map.Entry::getValue, false); - } - } - - static Account mockAccountWithDeviceAndEnabled(final Map enabledStateByDeviceId) { - final Account account = mock(Account.class); - final List devices = new ArrayList<>(); - - enabledStateByDeviceId.forEach((deviceId, enabled) -> { - final Device device = mock(Device.class); - when(device.getId()).thenReturn(deviceId); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - - devices.add(device); - }); - - when(account.getDevices()).thenReturn(devices); - - return account; - } - - static Stream validateCompleteDeviceList() { - final byte id1 = 1; - final byte id2 = 2; - final byte id3 = 3; - - final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)); - - return Stream.of( - // Device IDs provided for all enabled devices - arguments( - account, - Set.of(id1, id3), - Set.of(id2), - null, - Collections.emptySet()), - - // Device ID provided for disabled device - arguments( - account, - Set.of(id1, id2, id3), - null, - null, - Collections.emptySet()), - - // Device ID omitted for enabled device - arguments( - account, - Set.of(id1), - Set.of(id2, id3), - null, - Collections.emptySet()), - - // Device ID included for disabled device, omitted for enabled device - arguments( - account, - Set.of(id1, id2), - Set.of(id3), - null, - Collections.emptySet()), - - // Device ID omitted for enabled device, included for device in excluded list - arguments( - account, - Set.of(id1), - Set.of(id2, id3), - Set.of(id1), - Set.of(id1) - ), - - // Device ID omitted for enabled device, included for disabled device, omitted for excluded device - arguments( - account, - Set.of(id2), - Set.of(id3), - null, - Set.of(id1) - ), - - // Device ID included for enabled device, omitted for excluded device - arguments( - account, - Set.of(id3), - Set.of(id2), - null, - Set.of(id1) - ) - ); - } - - @ParameterizedTest - @MethodSource - void validateCompleteDeviceList( - Account account, - Set deviceIds, - Collection expectedMissingDeviceIds, - Collection expectedExtraDeviceIds, - Set excludedDeviceIds) throws Exception { - - if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { - final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, - () -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds)); - if (expectedMissingDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) - .hasSameElementsAs(expectedMissingDeviceIds); - } - if (expectedExtraDeviceIds != null) { - Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); - } - } else { - DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds); - } - } - - @Test - void testDuplicateDeviceIds() { - final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17)); - try { - DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false); - Assertions.fail("duplicate devices should throw StaleDevicesException"); - } catch (StaleDevicesException e) { - Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID)); - } - } - - @Test - void testValidatePniRegistrationIds() { - final Device device = mock(Device.class); - when(device.getId()).thenReturn(Device.PRIMARY_ID); - - final Account account = mock(Account.class); - when(account.getDevices()).thenReturn(List.of(device)); - when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device)); - - final int aciRegistrationId = 17; - final int pniRegistrationId = 89; - final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId; - - when(device.getRegistrationId()).thenReturn(aciRegistrationId); - when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId)); - - assertDoesNotThrow( - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false)); - assertDoesNotThrow( - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)), - true)); - assertThrows(StaleDevicesException.class, - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), - true)); - assertThrows(StaleDevicesException.class, - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)), - false)); - - when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); - - assertDoesNotThrow( - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), - false)); - assertDoesNotThrow( - () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), - true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true)); - assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account, - Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false)); - } -}