Internalize destination device list/registration ID checks in `MessageSender`

This commit is contained in:
Jon Chambers 2025-04-07 09:15:39 -04:00 committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 675 additions and 755 deletions

View File

@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
@ -93,8 +93,8 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password")
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class)))
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
@ApiResponse(responseCode = "422", description = "The request did not pass validation")
@ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class)))
@ -150,16 +150,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {
@ -178,9 +180,9 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.")
@ApiResponse(responseCode = "422", description = "The request body failed validation.")
@ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@ -207,16 +209,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
} catch (MessageTooLargeException e) {

View File

@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
@ -62,7 +60,6 @@ import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@ -245,10 +241,10 @@ public class MessageController {
description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user")
@ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for recipient",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(
responseCode="428",
description="The sender should complete a challenge before proceeding")
@ -381,14 +377,6 @@ public class MessageController {
rateLimiters.getStoriesLimiter().validate(destination.getUuid());
}
final Set<Byte> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
try {
@ -407,15 +395,8 @@ public class MessageController {
}
}));
DestinationDeviceValidator.validateCompleteDeviceList(destination,
messagesByDeviceId.keySet(),
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination,
messages.messages(),
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final String authType;
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
@ -428,7 +409,7 @@ public class MessageController {
authType = AUTH_TYPE_ACCESS_KEY;
}
messageSender.sendMessages(destination, messagesByDeviceId);
messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
@ -440,16 +421,18 @@ public class MessageController {
return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (final MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (final StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build());
} else {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
e.getMismatchedDevices().extraDeviceIds()))
.build());
}
}
} finally {
sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME)
@ -622,57 +605,6 @@ public class MessageController {
}
}
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
// can just skip them.
return;
}
final Account account = resolvedRecipients.get(recipient);
try {
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceIdsToRegistrationIds.entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
serviceId instanceof ServiceId.Pni);
} catch (final MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
ServiceIdentifier.fromLibsignal(serviceId),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (final StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
final String authType;
if (isStory) {
authType = AUTH_TYPE_STORY;
@ -731,6 +663,38 @@ public class MessageController {
} catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
throw new InternalServerErrorException("failure during delivery");
} catch (MultiRecipientMismatchedDevicesException e) {
final List<AccountMismatchedDevices> accountMismatchedDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty())
.map(entry -> new AccountMismatchedDevices(entry.getKey(),
new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds())))
.toList();
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
final List<AccountStaleDevices> accountStaleDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().staleDeviceIds().isEmpty())
.map(entry -> new AccountStaleDevices(entry.getKey(),
new StaleDevicesResponse(entry.getValue().staleDeviceIds())))
.toList();
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
throw new RuntimeException(e);
}
}

View File

@ -0,0 +1,11 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.Set;
public record MismatchedDevices(Set<Byte> missingDeviceIds, Set<Byte> extraDeviceIds, Set<Byte> staleDeviceIds) {
}

View File

@ -5,23 +5,15 @@
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class MismatchedDevicesException extends Exception {
private final List<Byte> missingDevices;
private final List<Byte> extraDevices;
private final MismatchedDevices mismatchedDevices;
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) {
this.missingDevices = missingDevices;
this.extraDevices = extraDevices;
public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) {
this.mismatchedDevices = mismatchedDevices;
}
public List<Byte> getMissingDevices() {
return missingDevices;
}
public List<Byte> getExtraDevices() {
return extraDevices;
public MismatchedDevices getMismatchedDevices() {
return mismatchedDevices;
}
}

View File

@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
public class MultiRecipientMismatchedDevicesException extends Exception {
private final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier;
public MultiRecipientMismatchedDevicesException(
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier) {
this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier;
}
public Map<ServiceIdentifier, MismatchedDevices> getMismatchedDevicesByServiceIdentifier() {
return mismatchedDevicesByServiceIdentifier;
}
}

View File

@ -1,22 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Exception {
private final List<Byte> staleDevices;
public StaleDevicesException(List<Byte> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Byte> getStaleDevices() {
return staleDevices;
}
}

View File

@ -14,5 +14,5 @@ public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierA
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
MismatchedDevices devices) {
MismatchedDevicesResponse devices) {
}

View File

@ -14,5 +14,5 @@ public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapte
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
StaleDevices devices) {
StaleDevicesResponse devices) {
}

View File

@ -1,20 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
public record MismatchedDevices(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
List<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
List<Byte> extraDevices) {
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Set;
public record MismatchedDevicesResponse(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
Set<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
Set<Byte> extraDevices) {
}

View File

@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
import java.util.Set;
public record StaleDevices(@JsonProperty
@Schema(description = "Devices that are no longer active")
List<Byte> staleDevices) {
public record StaleDevicesResponse(@JsonProperty
@Schema(description = "Devices that are no longer active")
Set<Byte> staleDevices) {
}

View File

@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
@ -58,6 +69,9 @@ public class MessageSender {
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
@ -68,23 +82,51 @@ public class MessageSender {
* notification token and does not have an active connection to a Signal server, then this method will also send a
* push notification to that device to announce the availability of new messages.
*
* @param account the account to which to send messages
* @param destination the account to which to send messages
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
*/
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
public void sendMessages(final Account destination,
final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId) throws MismatchedDevicesException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral()) {
try {
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) {
}
}
Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
@ -98,6 +140,10 @@ public class MessageSender {
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to
* announce the availability of new messages.
* <p>
* This method sends messages to all <em>resolved</em> recipients. In some cases, a caller may not be able to resolve
* all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting
* the message if they require full resolution of all recipients, but some recipients could not be resolved.
*
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
* @param resolvedRecipients a map of recipients to resolved Signal accounts
@ -114,7 +160,31 @@ public class MessageSender {
final long clientTimestamp,
final boolean isStory,
final boolean isEphemeral,
final boolean isUrgent) {
final boolean isUrgent) throws MultiRecipientMismatchedDevicesException {
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run
// into an unresolved recipient here, just skip it.
return;
}
final Account account = resolvedRecipients.get(recipient);
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId);
final Map<Byte, Integer> registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair::first, pair -> (int) pair.second()));
getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID)
.ifPresent(mismatchedDevices ->
mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices));
});
if (!mismatchedDevicesByServiceIdentifier.isEmpty()) {
throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier);
}
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
isStory, isEphemeral, isUrgent)
@ -189,4 +259,47 @@ public class MessageSender {
.increment();
}
}
@VisibleForTesting
static Optional<MismatchedDevices> getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId) {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != excludedDeviceId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet());
final Set<Byte> extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet());
extraDeviceIds.removeAll(accountDeviceIds);
final Set<Byte> staleDeviceIds = registrationIdsByDeviceId.entrySet().stream()
// Filter out device IDs that aren't associated with the given account
.filter(entry -> !extraDeviceIds.contains(entry.getKey()))
.filter(entry -> {
final byte deviceId = entry.getKey();
final int registrationId = entry.getValue();
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
};
return registrationId != expectedRegistrationId;
})
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty())
? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds))
: Optional.empty();
}
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.slf4j.Logger;
@ -54,9 +55,20 @@ public class ReceiptSender {
.setUrgent(false)
.build();
final Map<Byte, Envelope> messagesByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
}));
try {
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message)));
messageSender.sendMessages(destinationAccount,
destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
} catch (final Exception e) {
logger.warn("Could not send delivery receipt", e);
}

View File

@ -37,10 +37,12 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@ -55,6 +57,7 @@ import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.Mac;
@ -67,6 +70,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@ -83,7 +87,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -780,24 +783,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
// Check that all including primary ID are in signed pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
// Check that all including primary ID are in Pq pre-keys
if (pniPqLastResortPreKeys != null) {
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniPqLastResortPreKeys.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
}
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniRegistrationIds.keySet(),
Collections.emptySet());
validateCompleteDeviceList(account, pniRegistrationIds.keySet());
}
@VisibleForTesting
static void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(deviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(deviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
}
}
public record UsernameReservation(Account account, byte[] reservedUsernameHash){}

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
@ -15,15 +14,15 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
public class ChangeNumberManager {
@ -45,12 +44,10 @@ public class ChangeNumberManager {
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Byte, Integer> pniRegistrationIds,
@Nullable final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
// AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds
validateDeviceMessages(account, deviceMessages);
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds))) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
}
@ -84,9 +81,7 @@ public class ChangeNumberManager {
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Byte, Integer> pniRegistrationIds,
final String senderUserAgent) throws MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
validateDeviceMessages(account, deviceMessages);
final String senderUserAgent) throws MismatchedDevicesException, MessageTooLargeException {
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations.
@ -97,26 +92,9 @@ public class ChangeNumberManager {
return updatedAccount;
}
private void validateDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages) throws MismatchedDevicesException, StaleDevicesException {
// Check that all except primary ID are in device messages
DestinationDeviceValidator.validateCompleteDeviceList(
account,
deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()),
Set.of(Device.PRIMARY_ID));
// check that all sync messages are to the current registration ID for the matching device
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceMessages,
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
false);
}
private void sendDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages,
final String senderUserAgent) throws MessageTooLargeException {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
for (final IncomingMessage message : deviceMessages) {
MessageSender.validateContentLength(message.content().length,
@ -128,20 +106,26 @@ public class ChangeNumberManager {
try {
final long serverTimestamp = System.currentTimeMillis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
messageSender.sendMessages(account, deviceMessages.stream()
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type()))
.setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp)
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setContent(ByteString.copyFrom(message.content()))
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.setEphemeral(false)
.build())));
.build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
} catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e;

View File

@ -1,108 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator {
/**
* @see #validateRegistrationIds(Account, Stream, boolean)
*/
public static <T> void validateRegistrationIds(final Account account,
final Collection<T> messages,
Function<T, Byte> getDeviceId,
Function<T, Integer> getRegistrationId,
boolean usePhoneNumberIdentity) throws StaleDevicesException {
validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
usePhoneNumberIdentity);
}
/**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs.
*
* @param account the destination account against which to check the given device
* ID/registration ID pairs
* @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
* registration IDs associated with the account's PNI (if available); compare
* against the ACI-associated registration ID otherwise
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Stream<Pair<Byte, Integer>> deviceIdAndRegistrationIdStream,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Byte> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
final byte deviceId = deviceIdAndRegistrationId.first();
final int registrationId = deviceIdAndRegistrationId.second();
boolean registrationIdMatches = account.getDevice(deviceId)
.map(device -> registrationId == (usePhoneNumberIdentity
? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
: device.getRegistrationId()))
.orElse(false);
return !registrationIdMatches;
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
/**
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
* given destination account in preparation for sending those messages to the destination account. In general, the set
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
*
* @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message)
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Byte> messageDeviceIds,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
}
}
}

View File

@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import javax.annotation.Nullable;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@ -195,7 +197,7 @@ class MessageControllerTest {
.build();
@BeforeEach
void setup() {
void setup() throws MultiRecipientMismatchedDevicesException {
reset(pushNotificationScheduler);
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
@ -287,7 +289,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -332,7 +334,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -357,7 +359,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -395,7 +397,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -434,7 +436,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture());
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -530,6 +532,9 @@ class MessageControllerTest {
@Test
void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -542,15 +547,16 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@Test
void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -563,10 +569,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)),
asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@ -602,7 +606,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@ -626,7 +630,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size());
@ -648,12 +652,17 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(Account.class),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3));
any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any());
}
}
@Test
void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request()
@ -665,10 +674,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
assertThat("Good Response Body",
asJson(response.readEntity(StaleDevices.class)),
asJson(response.readEntity(StaleDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
verifyNoMoreInteractions(messageSender);
}
}
@ -1078,7 +1085,7 @@ class MessageControllerTest {
}
@Test
void testValidateContentLength() {
void testValidateContentLength() throws MismatchedDevicesException {
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1);
@ -1095,7 +1102,7 @@ class MessageControllerTest {
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
@ -1113,10 +1120,10 @@ class MessageControllerTest {
if (expectOk) {
assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any());
verify(messageSender).sendMessages(any(), any(), any(), any());
} else {
assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any());
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
}
}
}
@ -1140,7 +1147,9 @@ class MessageControllerTest {
final Optional<String> maybeGroupSendToken,
final int expectedStatus,
final Set<Account> expectedResolvedAccounts,
final Set<ServiceIdentifier> expectedUuids404) {
final Set<ServiceIdentifier> expectedUuids404,
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
throws MultiRecipientMismatchedDevicesException {
clock.pin(START_OF_DAY);
@ -1151,6 +1160,11 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
if (mismatchedDevicesException != null) {
doThrow(mismatchedDevicesException)
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
}
final boolean ephemeral = true;
final boolean urgent = false;
@ -1187,7 +1201,7 @@ class MessageControllerTest {
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
}
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) {
if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
verify(messageSender).sendMultiRecipientMessage(any(),
argThat(resolvedRecipients ->
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
@ -1267,7 +1281,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with combined UAKs",
accountsByServiceIdentifier,
@ -1279,7 +1294,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with group send endorsement",
accountsByServiceIdentifier,
@ -1291,7 +1307,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect combined UAK",
accountsByServiceIdentifier,
@ -1303,7 +1320,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Incorrect group send endorsement",
accountsByServiceIdentifier,
@ -1317,7 +1335,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
// provided
@ -1331,7 +1350,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with group send endorsement",
accountsByServiceIdentifier,
@ -1343,7 +1363,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Conflicting credentials",
accountsByServiceIdentifier,
@ -1355,7 +1376,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("No credentials",
accountsByServiceIdentifier,
@ -1367,7 +1389,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Oversized payload",
accountsByServiceIdentifier,
@ -1383,7 +1406,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
413,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Negative timestamp",
accountsByServiceIdentifier,
@ -1395,7 +1419,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Excessive timestamp",
accountsByServiceIdentifier,
@ -1407,7 +1432,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Empty recipient list",
accountsByServiceIdentifier,
@ -1421,7 +1447,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story with empty recipient list",
accountsByServiceIdentifier,
@ -1433,7 +1460,8 @@ class MessageControllerTest {
Optional.empty(),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Duplicate recipient",
accountsByServiceIdentifier,
@ -1447,7 +1475,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
400,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing account",
Map.of(),
@ -1459,7 +1488,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Collections.emptySet(),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("One missing and one existing account",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@ -1471,7 +1501,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
200,
Set.of(singleDeviceAccount),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("Missing account for story",
Map.of(),
@ -1483,7 +1514,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Collections.emptySet(),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("One missing and one existing account for story",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@ -1495,7 +1527,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Missing device",
accountsByServiceIdentifier,
@ -1509,7 +1542,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
Arguments.argumentSet("Extra device",
accountsByServiceIdentifier,
@ -1525,7 +1560,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
409,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
Arguments.argumentSet("Stale registration ID",
accountsByServiceIdentifier,
@ -1540,7 +1577,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement),
410,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
Arguments.argumentSet("Rate-limited story",
accountsByServiceIdentifier,
@ -1552,7 +1591,8 @@ class MessageControllerTest {
Optional.empty(),
429,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Story to PNI recipients",
accountsByServiceIdentifier,
@ -1567,7 +1607,8 @@ class MessageControllerTest {
Optional.empty(),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
accountsByServiceIdentifier,
@ -1582,7 +1623,8 @@ class MessageControllerTest {
Optional.empty(),
401,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()),
Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
accountsByServiceIdentifier,
@ -1599,7 +1641,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))),
200,
Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of())
Set.of(),
null)
);
}

View File

@ -22,6 +22,9 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
class MessageSenderTest {
@ -60,7 +73,9 @@ class MessageSenderTest {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
@ -71,7 +86,11 @@ class MessageSenderTest {
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
@ -82,7 +101,10 @@ class MessageSenderTest {
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId)));
final MessageProtos.Envelope expectedMessage = ephemeral
? message.toBuilder().setEphemeral(true).build()
@ -97,23 +119,61 @@ class MessageSenderTest {
}
}
@Test
void sendMessageMismatchedDevices() {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final MismatchedDevicesException mismatchedDevicesException =
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId + 1)));
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
mismatchedDevicesException.getMismatchedDevices());
}
@CartesianTest
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
@ -125,12 +185,19 @@ class MessageSenderTest {
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
Collections.emptyMap(),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
final SealedSenderMultiRecipientMessage.Recipient recipient =
multiRecipientMessage.getRecipients().values().iterator().next();
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
.join());
if (expectPushNotificationAttempt) {
@ -140,6 +207,49 @@ class MessageSenderTest {
}
}
@Test
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
final SealedSenderMultiRecipientMessage.Recipient recipient =
multiRecipientMessage.getRecipients().values().iterator().next();
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
assertThrows(MultiRecipientMismatchedDevicesException.class,
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
false,
true)
.join());
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
}
@ParameterizedTest
@MethodSource
void getDeliveryChannelName(final Device device, final String expectedChannelName) {
@ -183,4 +293,87 @@ class MessageSenderTest {
assertDoesNotThrow(() ->
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
}
@ParameterizedTest
@MethodSource
void getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
assertEquals(expectedMismatchedDevices,
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
}
private static List<Arguments> getMismatchedDevices() {
final byte primaryDeviceId = Device.PRIMARY_ID;
final byte linkedDeviceId = primaryDeviceId + 1;
final byte extraDeviceId = linkedDeviceId + 1;
final int primaryDeviceAciRegistrationId = 2;
final int primaryDevicePniRegistrationId = 3;
final int linkedDeviceAciRegistrationId = 5;
final int linkedDevicePniRegistrationId = 7;
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
return List.of(
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
account,
aciServiceIdentifier,
Map.of(
primaryDeviceId, primaryDeviceAciRegistrationId,
linkedDeviceId, linkedDeviceAciRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
account,
pniServiceIdentifier,
Map.of(
primaryDeviceId, primaryDevicePniRegistrationId,
linkedDeviceId, linkedDevicePniRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list, device excluded",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId
),
primaryDeviceId,
Optional.empty()),
Arguments.argumentSet("Mismatched devices",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
extraDeviceId, 17
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
);
}
}

View File

@ -60,10 +60,12 @@ import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.spec.SecretKeySpec;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@ -1705,4 +1708,47 @@ class AccountsManagerTest {
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds, @Nullable final MismatchedDevicesException expectedException) {
final Executable validateCompleteDeviceListExecutable =
() -> AccountsManager.validateCompleteDeviceList(account, deviceIds);
if (expectedException != null) {
final MismatchedDevicesException caughtException =
assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable);
assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices());
} else {
assertDoesNotThrow(validateCompleteDeviceListExecutable);
}
}
private static List<Arguments> validateCompleteDeviceList() {
final byte deviceId = Device.PRIMARY_ID;
final byte extraDeviceId = deviceId + 1;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
return List.of(
Arguments.of(account, Set.of(deviceId), null),
Arguments.of(account, Set.of(deviceId, extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))),
Arguments.of(account, Collections.emptySet(),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))),
Arguments.of(account, Set.of(extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
);
}
}

View File

@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -105,7 +106,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessages(eq(account), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
}
@Test
@ -119,7 +120,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any());
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
}
@Test
@ -159,7 +160,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -212,7 +213,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -263,7 +264,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -310,7 +311,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -359,7 +360,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -372,82 +373,6 @@ public class ChangeNumberManagerTest {
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@Test
void changeNumberMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test
void updatePniKeysMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test
void changeNumberMissingData() {
final Account account = mock(Account.class);

View File

@ -1,273 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ExtendWith(DropwizardExtensionsSupport.class)
class DestinationDeviceValidatorTest {
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
final Account account = mock(Account.class);
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
final Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
});
return account;
}
static Stream<Arguments> validateRegistrationIdsSource() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 1492),
Set.of(id1)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
Map.of(id1, 0, id2, 42),
Set.of(id2)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
Map.of(id1, 41, id2, 257),
Set.of(id1, id2))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Map<Byte, Integer> registrationIdsByDeviceId,
Set<Byte> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(
account,
registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey,
Map.Entry::getValue,
false))
.getStaleDevices())
.hasSameElementsAs(expectedStaleDeviceIds);
} else {
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey, Map.Entry::getValue, false);
}
}
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
});
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceList() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true));
return Stream.of(
// Device IDs provided for all enabled devices
arguments(
account,
Set.of(id1, id3),
Set.of(id2),
null,
Collections.emptySet()),
// Device ID provided for disabled device
arguments(
account,
Set.of(id1, id2, id3),
null,
null,
Collections.emptySet()),
// Device ID omitted for enabled device
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
null,
Collections.emptySet()),
// Device ID included for disabled device, omitted for enabled device
arguments(
account,
Set.of(id1, id2),
Set.of(id3),
null,
Collections.emptySet()),
// Device ID omitted for enabled device, included for device in excluded list
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
Set.of(id1),
Set.of(id1)
),
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
arguments(
account,
Set.of(id2),
Set.of(id3),
null,
Set.of(id1)
),
// Device ID included for enabled device, omitted for excluded device
arguments(
account,
Set.of(id3),
Set.of(id2),
null,
Set.of(id1)
)
);
}
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(
Account account,
Set<Byte> deviceIds,
Collection<Byte> expectedMissingDeviceIds,
Collection<Byte> expectedExtraDeviceIds,
Set<Byte> excludedDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
}
}
@Test
void testDuplicateDeviceIds() {
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17));
try {
DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false);
Assertions.fail("duplicate devices should throw StaleDevicesException");
} catch (StaleDevicesException e) {
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID));
}
}
@Test
void testValidatePniRegistrationIds() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
final int aciRegistrationId = 17;
final int pniRegistrationId = 89;
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
false));
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false));
}
}