Internalize destination device list/registration ID checks in `MessageSender`

This commit is contained in:
Jon Chambers 2025-04-07 09:15:39 -04:00 committed by GitHub
parent 1d0e2d29a7
commit c6689ca07a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 675 additions and 755 deletions

View File

@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest; import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest; import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
@ -93,8 +93,8 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password") @ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password")
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) @ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class))) @ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large") @ApiResponse(responseCode = "413", description = "One or more device messages was too large")
@ApiResponse(responseCode = "422", description = "The request did not pass validation") @ApiResponse(responseCode = "422", description = "The request did not pass validation")
@ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class))) @ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class)))
@ -150,16 +150,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount); return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409) if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
.type(MediaType.APPLICATION_JSON_TYPE) throw new WebApplicationException(Response.status(410)
.entity(new MismatchedDevices(e.getMissingDevices(), .type(MediaType.APPLICATION_JSON)
e.getExtraDevices())) .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build()); .build());
} catch (StaleDevicesException e) { } else {
throw new WebApplicationException(Response.status(410) throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON) .type(MediaType.APPLICATION_JSON_TYPE)
.entity(new StaleDevices(e.getStaleDevices())) .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
.build()); e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
throw new BadRequestException(e); throw new BadRequestException(e);
} catch (MessageTooLargeException e) { } catch (MessageTooLargeException e) {
@ -178,9 +180,9 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.") @ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.")
@ApiResponse(responseCode = "422", description = "The request body failed validation.") @ApiResponse(responseCode = "422", description = "The request body failed validation.")
@ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.", @ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.", @ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.",
content = @Content(schema = @Schema(implementation = StaleDevices.class))) content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large") @ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys( public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice, @Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@ -207,16 +209,18 @@ public class AccountControllerV2 {
return AccountIdentityResponseBuilder.fromAccount(updatedAccount); return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409) if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
.type(MediaType.APPLICATION_JSON_TYPE) throw new WebApplicationException(Response.status(410)
.entity(new MismatchedDevices(e.getMissingDevices(), .type(MediaType.APPLICATION_JSON)
e.getExtraDevices())) .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build()); .build());
} catch (StaleDevicesException e) { } else {
throw new WebApplicationException(Response.status(410) throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON) .type(MediaType.APPLICATION_JSON_TYPE)
.entity(new StaleDevices(e.getStaleDevices())) .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
.build()); e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
throw new BadRequestException(e); throw new BadRequestException(e);
} catch (MessageTooLargeException e) { } catch (MessageTooLargeException e) {

View File

@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status; import jakarta.ws.rs.core.Response.Status;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -62,7 +60,6 @@ import javax.annotation.Nullable;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair; import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -245,10 +241,10 @@ public class MessageController {
description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user") description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user")
@ApiResponse( @ApiResponse(
responseCode = "409", description = "Incorrect set of devices supplied for recipient", responseCode = "409", description = "Incorrect set of devices supplied for recipient",
content = @Content(schema = @Schema(implementation = MismatchedDevices.class))) content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
@ApiResponse( @ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = StaleDevices.class))) content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse( @ApiResponse(
responseCode="428", responseCode="428",
description="The sender should complete a challenge before proceeding") description="The sender should complete a challenge before proceeding")
@ -381,14 +377,6 @@ public class MessageController {
rateLimiters.getStoriesLimiter().validate(destination.getUuid()); rateLimiters.getStoriesLimiter().validate(destination.getUuid());
} }
final Set<Byte> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
} else {
excludedDeviceIds = Collections.emptySet();
}
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream() final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> { .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
try { try {
@ -407,15 +395,8 @@ public class MessageController {
} }
})); }));
DestinationDeviceValidator.validateCompleteDeviceList(destination, final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
messagesByDeviceId.keySet(), .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
excludedDeviceIds);
DestinationDeviceValidator.validateRegistrationIds(destination,
messages.messages(),
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final String authType; final String authType;
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) { if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
@ -428,7 +409,7 @@ public class MessageController {
authType = AUTH_TYPE_ACCESS_KEY; authType = AUTH_TYPE_ACCESS_KEY;
} }
messageSender.sendMessages(destination, messagesByDeviceId); messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent), Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE), Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
@ -440,16 +421,18 @@ public class MessageController {
return Response.ok(new SendMessageResponse(needsSync)).build(); return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (final MismatchedDevicesException e) { } catch (final MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409) if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
.type(MediaType.APPLICATION_JSON_TYPE) throw new WebApplicationException(Response.status(410)
.entity(new MismatchedDevices(e.getMissingDevices(), .type(MediaType.APPLICATION_JSON)
e.getExtraDevices())) .entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
.build()); .build());
} catch (final StaleDevicesException e) { } else {
throw new WebApplicationException(Response.status(410) throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON) .type(MediaType.APPLICATION_JSON_TYPE)
.entity(new StaleDevices(e.getStaleDevices())) .entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
.build()); e.getMismatchedDevices().extraDeviceIds()))
.build());
}
} }
} finally { } finally {
sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME) sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME)
@ -622,57 +605,6 @@ public class MessageController {
} }
} }
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
// can just skip them.
return;
}
final Account account = resolvedRecipients.get(recipient);
try {
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
Collections.emptySet());
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceIdsToRegistrationIds.entrySet(),
Map.Entry<Byte, Short>::getKey,
e -> Integer.valueOf(e.getValue()),
serviceId instanceof ServiceId.Pni);
} catch (final MismatchedDevicesException e) {
accountMismatchedDevices.add(
new AccountMismatchedDevices(
ServiceIdentifier.fromLibsignal(serviceId),
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (final StaleDevicesException e) {
accountStaleDevices.add(
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
}
});
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
final String authType; final String authType;
if (isStory) { if (isStory) {
authType = AUTH_TYPE_STORY; authType = AUTH_TYPE_STORY;
@ -731,6 +663,38 @@ public class MessageController {
} catch (ExecutionException e) { } catch (ExecutionException e) {
logger.error("partial failure while delivering multi-recipient messages", e.getCause()); logger.error("partial failure while delivering multi-recipient messages", e.getCause());
throw new InternalServerErrorException("failure during delivery"); throw new InternalServerErrorException("failure during delivery");
} catch (MultiRecipientMismatchedDevicesException e) {
final List<AccountMismatchedDevices> accountMismatchedDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty())
.map(entry -> new AccountMismatchedDevices(entry.getKey(),
new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds())))
.toList();
if (!accountMismatchedDevices.isEmpty()) {
return Response
.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(accountMismatchedDevices)
.build();
}
final List<AccountStaleDevices> accountStaleDevices =
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
.filter(entry -> !entry.getValue().staleDeviceIds().isEmpty())
.map(entry -> new AccountStaleDevices(entry.getKey(),
new StaleDevicesResponse(entry.getValue().staleDeviceIds())))
.toList();
if (!accountStaleDevices.isEmpty()) {
return Response
.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(accountStaleDevices)
.build();
}
throw new RuntimeException(e);
} }
} }

View File

@ -0,0 +1,11 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.Set;
public record MismatchedDevices(Set<Byte> missingDeviceIds, Set<Byte> extraDeviceIds, Set<Byte> staleDeviceIds) {
}

View File

@ -5,23 +5,15 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class MismatchedDevicesException extends Exception { public class MismatchedDevicesException extends Exception {
private final List<Byte> missingDevices; private final MismatchedDevices mismatchedDevices;
private final List<Byte> extraDevices;
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) { public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) {
this.missingDevices = missingDevices; this.mismatchedDevices = mismatchedDevices;
this.extraDevices = extraDevices;
} }
public List<Byte> getMissingDevices() { public MismatchedDevices getMismatchedDevices() {
return missingDevices; return mismatchedDevices;
}
public List<Byte> getExtraDevices() {
return extraDevices;
} }
} }

View File

@ -0,0 +1,24 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
public class MultiRecipientMismatchedDevicesException extends Exception {
private final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier;
public MultiRecipientMismatchedDevicesException(
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier) {
this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier;
}
public Map<ServiceIdentifier, MismatchedDevices> getMismatchedDevicesByServiceIdentifier() {
return mismatchedDevicesByServiceIdentifier;
}
}

View File

@ -1,22 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Exception {
private final List<Byte> staleDevices;
public StaleDevicesException(List<Byte> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Byte> getStaleDevices() {
return staleDevices;
}
}

View File

@ -14,5 +14,5 @@ public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierA
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid, ServiceIdentifier uuid,
MismatchedDevices devices) { MismatchedDevicesResponse devices) {
} }

View File

@ -14,5 +14,5 @@ public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapte
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid, ServiceIdentifier uuid,
StaleDevices devices) { StaleDevicesResponse devices) {
} }

View File

@ -1,20 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
public record MismatchedDevices(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
List<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
List<Byte> extraDevices) {
}

View File

@ -0,0 +1,20 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.Set;
public record MismatchedDevicesResponse(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
Set<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
Set<Byte> extraDevices) {
}

View File

@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List; import java.util.Set;
public record StaleDevices(@JsonProperty public record StaleDevicesResponse(@JsonProperty
@Schema(description = "Devices that are no longer active") @Schema(description = "Devices that are no longer active")
List<Byte> staleDevices) { Set<Byte> staleDevices) {
} }

View File

@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.util.DataSize; import io.dropwizard.util.DataSize;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -58,6 +69,9 @@ public class MessageSender {
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes(); public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes(); private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
@VisibleForTesting
static final byte NO_EXCLUDED_DEVICE_ID = -1;
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) { public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
@ -68,23 +82,51 @@ public class MessageSender {
* notification token and does not have an active connection to a Signal server, then this method will also send a * notification token and does not have an active connection to a Signal server, then this method will also send a
* push notification to that device to announce the availability of new messages. * push notification to that device to announce the availability of new messages.
* *
* @param account the account to which to send messages * @param destination the account to which to send messages
* @param destinationIdentifier the service identifier to which the messages are addressed
* @param messagesByDeviceId a map of device IDs to message payloads * @param messagesByDeviceId a map of device IDs to message payloads
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
*/ */
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) { public void sendMessages(final Account destination,
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId) final ServiceIdentifier destinationIdentifier,
final Map<Byte, Envelope> messagesByDeviceId,
final Map<Byte, Integer> registrationIdsByDeviceId) throws MismatchedDevicesException {
if (messagesByDeviceId.isEmpty()) {
return;
}
if (!destination.isIdentifiedBy(destinationIdentifier)) {
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
}
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
destinationIdentifier,
registrationIdsByDeviceId,
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
if (maybeMismatchedDevices.isPresent()) {
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
}
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
.forEach((deviceId, destinationPresent) -> { .forEach((deviceId, destinationPresent) -> {
final Envelope message = messagesByDeviceId.get(deviceId); final Envelope message = messagesByDeviceId.get(deviceId);
if (!destinationPresent && !message.getEphemeral()) { if (!destinationPresent && !message.getEphemeral()) {
try { try {
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent()); pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
} catch (final NotPushRegisteredException ignored) { } catch (final NotPushRegisteredException ignored) {
} }
} }
Metrics.counter(SEND_COUNTER_NAME, Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"), CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()), EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent), CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()), URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
@ -98,6 +140,10 @@ public class MessageSender {
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not * Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
* have an active connection to a Signal server, then this method will also send a push notification to that device to * have an active connection to a Signal server, then this method will also send a push notification to that device to
* announce the availability of new messages. * announce the availability of new messages.
* <p>
* This method sends messages to all <em>resolved</em> recipients. In some cases, a caller may not be able to resolve
* all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting
* the message if they require full resolution of all recipients, but some recipients could not be resolved.
* *
* @param multiRecipientMessage the multi-recipient message to send to the given recipients * @param multiRecipientMessage the multi-recipient message to send to the given recipients
* @param resolvedRecipients a map of recipients to resolved Signal accounts * @param resolvedRecipients a map of recipients to resolved Signal accounts
@ -114,7 +160,31 @@ public class MessageSender {
final long clientTimestamp, final long clientTimestamp,
final boolean isStory, final boolean isStory,
final boolean isEphemeral, final boolean isEphemeral,
final boolean isUrgent) { final boolean isUrgent) throws MultiRecipientMismatchedDevicesException {
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
if (!resolvedRecipients.containsKey(recipient)) {
// Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run
// into an unresolved recipient here, just skip it.
return;
}
final Account account = resolvedRecipients.get(recipient);
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId);
final Map<Byte, Integer> registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds()
.collect(Collectors.toMap(Pair::first, pair -> (int) pair.second()));
getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID)
.ifPresent(mismatchedDevices ->
mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices));
});
if (!mismatchedDevicesByServiceIdentifier.isEmpty()) {
throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier);
}
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp, return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
isStory, isEphemeral, isUrgent) isStory, isEphemeral, isUrgent)
@ -189,4 +259,47 @@ public class MessageSender {
.increment(); .increment();
} }
} }
@VisibleForTesting
static Optional<MismatchedDevices> getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId) {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != excludedDeviceId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet());
final Set<Byte> extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet());
extraDeviceIds.removeAll(accountDeviceIds);
final Set<Byte> staleDeviceIds = registrationIdsByDeviceId.entrySet().stream()
// Filter out device IDs that aren't associated with the given account
.filter(entry -> !extraDeviceIds.contains(entry.getKey()))
.filter(entry -> {
final byte deviceId = entry.getKey();
final int registrationId = entry.getValue();
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
};
return registrationId != expectedRegistrationId;
})
.map(Map.Entry::getKey)
.collect(Collectors.toSet());
return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty())
? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds))
: Optional.empty();
}
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.Map;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -54,9 +55,20 @@ public class ReceiptSender {
.setUrgent(false) .setUrgent(false)
.build(); .build();
final Map<Byte, Envelope> messagesByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
}));
try { try {
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream() messageSender.sendMessages(destinationAccount,
.collect(Collectors.toMap(Device::getId, ignored -> message))); destinationIdentifier,
messagesByDeviceId,
registrationIdsByDeviceId);
} catch (final Exception e) { } catch (final Exception e) {
logger.warn("Could not send delivery receipt", e); logger.warn("Could not send delivery receipt", e);
} }

View File

@ -37,10 +37,12 @@ import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -55,6 +57,7 @@ import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.crypto.Mac; import javax.crypto.Mac;
@ -67,6 +70,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@ -83,7 +87,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -780,24 +783,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} }
// Check that all including primary ID are in signed pre-keys // Check that all including primary ID are in signed pre-keys
DestinationDeviceValidator.validateCompleteDeviceList( validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all including primary ID are in Pq pre-keys // Check that all including primary ID are in Pq pre-keys
if (pniPqLastResortPreKeys != null) { if (pniPqLastResortPreKeys != null) {
DestinationDeviceValidator.validateCompleteDeviceList( validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
account,
pniPqLastResortPreKeys.keySet(),
Collections.emptySet());
} }
// Check that all devices are accounted for in the map of new PNI registration IDs // Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList( validateCompleteDeviceList(account, pniRegistrationIds.keySet());
account, }
pniRegistrationIds.keySet(),
Collections.emptySet()); @VisibleForTesting
static void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(deviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(deviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
}
} }
public record UsernameReservation(Account account, byte[] reservedUsernameHash){} public record UsernameReservation(Account account, byte[] reservedUsernameHash){}

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
@ -15,15 +14,15 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
public class ChangeNumberManager { public class ChangeNumberManager {
@ -45,12 +44,10 @@ public class ChangeNumberManager {
@Nullable final List<IncomingMessage> deviceMessages, @Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Byte, Integer> pniRegistrationIds, @Nullable final Map<Byte, Integer> pniRegistrationIds,
@Nullable final String senderUserAgent) @Nullable final String senderUserAgent)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException, MessageTooLargeException { throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
// AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds))) {
validateDeviceMessages(account, deviceMessages);
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null"); throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
} }
@ -84,9 +81,7 @@ public class ChangeNumberManager {
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys, @Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages, final List<IncomingMessage> deviceMessages,
final Map<Byte, Integer> pniRegistrationIds, final Map<Byte, Integer> pniRegistrationIds,
final String senderUserAgent) throws MismatchedDevicesException, StaleDevicesException, MessageTooLargeException { final String senderUserAgent) throws MismatchedDevicesException, MessageTooLargeException {
validateDeviceMessages(account, deviceMessages);
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations. // write anyway. Linked devices can handle some wasted extra key rotations.
@ -97,26 +92,9 @@ public class ChangeNumberManager {
return updatedAccount; return updatedAccount;
} }
private void validateDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages) throws MismatchedDevicesException, StaleDevicesException {
// Check that all except primary ID are in device messages
DestinationDeviceValidator.validateCompleteDeviceList(
account,
deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()),
Set.of(Device.PRIMARY_ID));
// check that all sync messages are to the current registration ID for the matching device
DestinationDeviceValidator.validateRegistrationIds(
account,
deviceMessages,
IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId,
false);
}
private void sendDeviceMessages(final Account account, private void sendDeviceMessages(final Account account,
final List<IncomingMessage> deviceMessages, final List<IncomingMessage> deviceMessages,
final String senderUserAgent) throws MessageTooLargeException { final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
for (final IncomingMessage message : deviceMessages) { for (final IncomingMessage message : deviceMessages) {
MessageSender.validateContentLength(message.content().length, MessageSender.validateContentLength(message.content().length,
@ -128,20 +106,26 @@ public class ChangeNumberManager {
try { try {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = System.currentTimeMillis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
messageSender.sendMessages(account, deviceMessages.stream() final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder() .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
.setType(Envelope.Type.forNumber(message.type())) .setType(Envelope.Type.forNumber(message.type()))
.setClientTimestamp(serverTimestamp) .setClientTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) .setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
.setContent(ByteString.copyFrom(message.content())) .setContent(ByteString.copyFrom(message.content()))
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString()) .setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID) .setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(account.getPhoneNumberIdentifier().toString()) .setUpdatedPni(account.getPhoneNumberIdentifier().toString())
.setUrgent(true) .setUrgent(true)
.setEphemeral(false) .setEphemeral(false)
.build()))); .build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
} catch (final RuntimeException e) { } catch (final RuntimeException e) {
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e); logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
throw e; throw e;

View File

@ -1,108 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DestinationDeviceValidator {
/**
* @see #validateRegistrationIds(Account, Stream, boolean)
*/
public static <T> void validateRegistrationIds(final Account account,
final Collection<T> messages,
Function<T, Byte> getDeviceId,
Function<T, Integer> getRegistrationId,
boolean usePhoneNumberIdentity) throws StaleDevicesException {
validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
usePhoneNumberIdentity);
}
/**
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
* destination account are present in the given device ID/registration ID pairs.
*
* @param account the destination account against which to check the given device
* ID/registration ID pairs
* @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
* registration IDs associated with the account's PNI (if available); compare
* against the ACI-associated registration ID otherwise
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Stream<Pair<Byte, Integer>> deviceIdAndRegistrationIdStream,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Byte> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
final byte deviceId = deviceIdAndRegistrationId.first();
final int registrationId = deviceIdAndRegistrationId.second();
boolean registrationIdMatches = account.getDevice(deviceId)
.map(device -> registrationId == (usePhoneNumberIdentity
? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
: device.getRegistrationId()))
.orElse(false);
return !registrationIdMatches;
})
.map(Pair::first)
.collect(Collectors.toList());
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
/**
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
* given destination account in preparation for sending those messages to the destination account. In general, the set
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
*
* @param account the destination account against which to check the given set of device IDs
* @param messageDeviceIds the set of device IDs to check against the destination account
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
* present in the given set of device IDs (i.e. the device that is sending a sync message)
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
* the destination account or is missing entries associated with the destination
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Byte> messageDeviceIds,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
}
}
}

View File

@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import javax.annotation.Nullable;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest { class MessageControllerTest {
@ -195,7 +197,7 @@ class MessageControllerTest {
.build(); .build();
@BeforeEach @BeforeEach
void setup() { void setup() throws MultiRecipientMismatchedDevicesException {
reset(pushNotificationScheduler); reset(pushNotificationScheduler);
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
@ -287,7 +289,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -332,7 +334,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -357,7 +359,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -395,7 +397,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(200))); assertThat("Good Response", response.getStatus(), is(equalTo(200)));
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -434,7 +436,7 @@ class MessageControllerTest {
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse))); assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
if (expectedResponse == 200) { if (expectedResponse == 200) {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class); @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), captor.capture()); verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
assertEquals(1, captor.getValue().size()); assertEquals(1, captor.getValue().size());
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow(); final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
@ -530,6 +532,9 @@ class MessageControllerTest {
@Test @Test
void testMultiDeviceMissing() throws Exception { void testMultiDeviceMissing() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -542,15 +547,16 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body", assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)), asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response.json")))); is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
verifyNoMoreInteractions(messageSender);
} }
} }
@Test @Test
void testMultiDeviceExtra() throws Exception { void testMultiDeviceExtra() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) .target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
@ -563,10 +569,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(409))); assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
assertThat("Good Response Body", assertThat("Good Response Body",
asJson(response.readEntity(MismatchedDevices.class)), asJson(response.readEntity(MismatchedDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/missing_device_response2.json")))); is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
verifyNoMoreInteractions(messageSender);
} }
} }
@ -602,7 +606,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -626,7 +630,7 @@ class MessageControllerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
assertEquals(3, envelopeCaptor.getValue().size()); assertEquals(3, envelopeCaptor.getValue().size());
@ -648,12 +652,17 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(200))); assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
verify(messageSender).sendMessages(any(Account.class), verify(messageSender).sendMessages(any(Account.class),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3)); any(),
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
any());
} }
} }
@Test @Test
void testRegistrationIdMismatch() throws Exception { void testRegistrationIdMismatch() throws Exception {
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
.when(messageSender).sendMessages(any(), any(), any(), any());
try (final Response response = try (final Response response =
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID)) resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
.request() .request()
@ -665,10 +674,8 @@ class MessageControllerTest {
assertThat("Good Response Code", response.getStatus(), is(equalTo(410))); assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
assertThat("Good Response Body", assertThat("Good Response Body",
asJson(response.readEntity(StaleDevices.class)), asJson(response.readEntity(StaleDevicesResponse.class)),
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json")))); is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
verifyNoMoreInteractions(messageSender);
} }
} }
@ -1078,7 +1085,7 @@ class MessageControllerTest {
} }
@Test @Test
void testValidateContentLength() { void testValidateContentLength() throws MismatchedDevicesException {
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1); final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
final byte[] contentBytes = new byte[contentLength]; final byte[] contentBytes = new byte[contentLength];
Arrays.fill(contentBytes, (byte) 1); Arrays.fill(contentBytes, (byte) 1);
@ -1095,7 +1102,7 @@ class MessageControllerTest {
assertThat("Bad response", response.getStatus(), is(equalTo(413))); assertThat("Bad response", response.getStatus(), is(equalTo(413)));
verify(messageSender, never()).sendMessages(any(), any()); verify(messageSender, never()).sendMessages(any(), any(), any(), any());
} }
} }
@ -1113,10 +1120,10 @@ class MessageControllerTest {
if (expectOk) { if (expectOk) {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
verify(messageSender).sendMessages(any(), any()); verify(messageSender).sendMessages(any(), any(), any(), any());
} else { } else {
assertEquals(422, response.getStatus()); assertEquals(422, response.getStatus());
verify(messageSender, never()).sendMessages(any(), any()); verify(messageSender, never()).sendMessages(any(), any(), any(), any());
} }
} }
} }
@ -1140,7 +1147,9 @@ class MessageControllerTest {
final Optional<String> maybeGroupSendToken, final Optional<String> maybeGroupSendToken,
final int expectedStatus, final int expectedStatus,
final Set<Account> expectedResolvedAccounts, final Set<Account> expectedResolvedAccounts,
final Set<ServiceIdentifier> expectedUuids404) { final Set<ServiceIdentifier> expectedUuids404,
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
throws MultiRecipientMismatchedDevicesException {
clock.pin(START_OF_DAY); clock.pin(START_OF_DAY);
@ -1151,6 +1160,11 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
if (mismatchedDevicesException != null) {
doThrow(mismatchedDevicesException)
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
}
final boolean ephemeral = true; final boolean ephemeral = true;
final boolean urgent = false; final boolean urgent = false;
@ -1187,7 +1201,7 @@ class MessageControllerTest {
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404)); assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
} }
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) { if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
verify(messageSender).sendMultiRecipientMessage(any(), verify(messageSender).sendMultiRecipientMessage(any(),
argThat(resolvedRecipients -> argThat(resolvedRecipients ->
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)), new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
@ -1267,7 +1281,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with combined UAKs", Arguments.argumentSet("Multi-recipient message with combined UAKs",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1279,7 +1294,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Multi-recipient message with group send endorsement", Arguments.argumentSet("Multi-recipient message with group send endorsement",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1291,7 +1307,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Incorrect combined UAK", Arguments.argumentSet("Incorrect combined UAK",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1303,7 +1320,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
401, 401,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Incorrect group send endorsement", Arguments.argumentSet("Incorrect group send endorsement",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1317,7 +1335,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))), START_OF_DAY.plus(Duration.ofDays(1)))),
401, 401,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if // Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
// provided // provided
@ -1331,7 +1350,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Story with group send endorsement", Arguments.argumentSet("Story with group send endorsement",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1343,7 +1363,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Conflicting credentials", Arguments.argumentSet("Conflicting credentials",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1355,7 +1376,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("No credentials", Arguments.argumentSet("No credentials",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1367,7 +1389,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
401, 401,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Oversized payload", Arguments.argumentSet("Oversized payload",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1383,7 +1406,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
413, 413,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Negative timestamp", Arguments.argumentSet("Negative timestamp",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1395,7 +1419,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Excessive timestamp", Arguments.argumentSet("Excessive timestamp",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1407,7 +1432,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Empty recipient list", Arguments.argumentSet("Empty recipient list",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1421,7 +1447,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))), START_OF_DAY.plus(Duration.ofDays(1)))),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Story with empty recipient list", Arguments.argumentSet("Story with empty recipient list",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1433,7 +1460,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Duplicate recipient", Arguments.argumentSet("Duplicate recipient",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1447,7 +1475,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
400, 400,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Missing account", Arguments.argumentSet("Missing account",
Map.of(), Map.of(),
@ -1459,7 +1488,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
200, 200,
Collections.emptySet(), Collections.emptySet(),
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))), Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("One missing and one existing account", Arguments.argumentSet("One missing and one existing account",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount), Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@ -1471,7 +1501,8 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
200, 200,
Set.of(singleDeviceAccount), Set.of(singleDeviceAccount),
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))), Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
null),
Arguments.argumentSet("Missing account for story", Arguments.argumentSet("Missing account for story",
Map.of(), Map.of(),
@ -1483,7 +1514,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Collections.emptySet(), Collections.emptySet(),
Set.of()), Set.of(),
null),
Arguments.argumentSet("One missing and one existing account for story", Arguments.argumentSet("One missing and one existing account for story",
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount), Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
@ -1495,7 +1527,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Set.of(singleDeviceAccount), Set.of(singleDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Missing device", Arguments.argumentSet("Missing device",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1509,7 +1542,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
409, 409,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
Arguments.argumentSet("Extra device", Arguments.argumentSet("Extra device",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1525,7 +1560,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
409, 409,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
Arguments.argumentSet("Stale registration ID", Arguments.argumentSet("Stale registration ID",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1540,7 +1577,9 @@ class MessageControllerTest {
Optional.of(groupSendEndorsement), Optional.of(groupSendEndorsement),
410, 410,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
Arguments.argumentSet("Rate-limited story", Arguments.argumentSet("Rate-limited story",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1552,7 +1591,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
429, 429,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Story to PNI recipients", Arguments.argumentSet("Story to PNI recipients",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1567,7 +1607,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK", Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1582,7 +1623,8 @@ class MessageControllerTest {
Optional.empty(), Optional.empty(),
401, 401,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()), Set.of(),
null),
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement", Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
accountsByServiceIdentifier, accountsByServiceIdentifier,
@ -1599,7 +1641,8 @@ class MessageControllerTest {
START_OF_DAY.plus(Duration.ofDays(1)))), START_OF_DAY.plus(Duration.ofDays(1)))),
200, 200,
Set.of(singleDeviceAccount, multiDeviceAccount), Set.of(singleDeviceAccount, multiDeviceAccount),
Set.of()) Set.of(),
null)
); );
} }

View File

@ -22,6 +22,9 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.InvalidVersionException;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
class MessageSenderTest { class MessageSenderTest {
@ -60,7 +73,9 @@ class MessageSenderTest {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID; final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class); final Account account = mock(Account.class);
final Device device = mock(Device.class); final Device device = mock(Device.class);
@ -71,7 +86,11 @@ class MessageSenderTest {
when(account.getUuid()).thenReturn(accountIdentifier); when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
if (hasPushToken) { if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token"); when(device.getApnId()).thenReturn("apns-token");
@ -82,7 +101,10 @@ class MessageSenderTest {
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent)); when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message))); assertDoesNotThrow(() -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId)));
final MessageProtos.Envelope expectedMessage = ephemeral final MessageProtos.Envelope expectedMessage = ephemeral
? message.toBuilder().setEphemeral(true).build() ? message.toBuilder().setEphemeral(true).build()
@ -97,23 +119,61 @@ class MessageSenderTest {
} }
} }
@Test
void sendMessageMismatchedDevices() {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final MismatchedDevicesException mismatchedDevicesException =
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
serviceIdentifier,
Map.of(device.getId(), message),
Map.of(device.getId(), registrationId + 1)));
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
mismatchedDevicesException.getMismatchedDevices());
}
@CartesianTest @CartesianTest
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent, void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral, @CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
@CartesianTest.Values(booleans = {true, false}) final boolean urgent, @CartesianTest.Values(booleans = {true, false}) final boolean urgent,
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException { @CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral; final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID; final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class); final Account account = mock(Account.class);
final Device device = mock(Device.class); final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier); when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
if (hasPushToken) { if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token"); when(device.getApnId()).thenReturn("apns-token");
@ -125,12 +185,19 @@ class MessageSenderTest {
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean())) when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent)))); .thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class), final SealedSenderMultiRecipientMessage multiRecipientMessage =
Collections.emptyMap(), SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
System.currentTimeMillis(), List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
false,
ephemeral, final SealedSenderMultiRecipientMessage.Recipient recipient =
urgent) multiRecipientMessage.getRecipients().values().iterator().next();
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
ephemeral,
urgent)
.join()); .join());
if (expectPushNotificationAttempt) { if (expectPushNotificationAttempt) {
@ -140,6 +207,49 @@ class MessageSenderTest {
} }
} }
@Test
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
final UUID accountIdentifier = UUID.randomUUID();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
final byte deviceId = Device.PRIMARY_ID;
final int registrationId = 17;
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(accountIdentifier);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final SealedSenderMultiRecipientMessage multiRecipientMessage =
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
final SealedSenderMultiRecipientMessage.Recipient recipient =
multiRecipientMessage.getRecipients().values().iterator().next();
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
assertThrows(MultiRecipientMismatchedDevicesException.class,
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
Map.of(recipient, account),
System.currentTimeMillis(),
false,
false,
true)
.join());
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void getDeliveryChannelName(final Device device, final String expectedChannelName) { void getDeliveryChannelName(final Device device, final String expectedChannelName) {
@ -183,4 +293,87 @@ class MessageSenderTest {
assertDoesNotThrow(() -> assertDoesNotThrow(() ->
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null)); MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
} }
@ParameterizedTest
@MethodSource
void getMismatchedDevices(final Account account,
final ServiceIdentifier serviceIdentifier,
final Map<Byte, Integer> registrationIdsByDeviceId,
final byte excludedDeviceId,
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
assertEquals(expectedMismatchedDevices,
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
}
private static List<Arguments> getMismatchedDevices() {
final byte primaryDeviceId = Device.PRIMARY_ID;
final byte linkedDeviceId = primaryDeviceId + 1;
final byte extraDeviceId = linkedDeviceId + 1;
final int primaryDeviceAciRegistrationId = 2;
final int primaryDevicePniRegistrationId = 3;
final int linkedDeviceAciRegistrationId = 5;
final int linkedDevicePniRegistrationId = 7;
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
return List.of(
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
account,
aciServiceIdentifier,
Map.of(
primaryDeviceId, primaryDeviceAciRegistrationId,
linkedDeviceId, linkedDeviceAciRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
account,
pniServiceIdentifier,
Map.of(
primaryDeviceId, primaryDevicePniRegistrationId,
linkedDeviceId, linkedDevicePniRegistrationId
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.empty()),
Arguments.argumentSet("Complete device list, device excluded",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId
),
primaryDeviceId,
Optional.empty()),
Arguments.argumentSet("Mismatched devices",
account,
aciServiceIdentifier,
Map.of(
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
extraDeviceId, 17
),
MessageSender.NO_EXCLUDED_DEVICE_ID,
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
);
}
} }

View File

@ -60,10 +60,12 @@ import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.SecretKeySpec;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
@ -1705,4 +1708,47 @@ class AccountsManagerTest {
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp) Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
); );
} }
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds, @Nullable final MismatchedDevicesException expectedException) {
final Executable validateCompleteDeviceListExecutable =
() -> AccountsManager.validateCompleteDeviceList(account, deviceIds);
if (expectedException != null) {
final MismatchedDevicesException caughtException =
assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable);
assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices());
} else {
assertDoesNotThrow(validateCompleteDeviceListExecutable);
}
}
private static List<Arguments> validateCompleteDeviceList() {
final byte deviceId = Device.PRIMARY_ID;
final byte extraDeviceId = deviceId + 1;
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
return List.of(
Arguments.of(account, Set.of(deviceId), null),
Arguments.of(account, Set.of(deviceId, extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))),
Arguments.of(account, Collections.emptySet(),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))),
Arguments.of(account, Set.of(extraDeviceId),
new MismatchedDevicesException(
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
);
}
} }

View File

@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -105,7 +106,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessages(eq(account), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
} }
@Test @Test
@ -119,7 +120,7 @@ public class ChangeNumberManagerTest {
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessages(eq(account), any()); verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
} }
@Test @Test
@ -159,7 +160,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -212,7 +213,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -263,7 +264,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -310,7 +311,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -359,7 +360,7 @@ public class ChangeNumberManagerTest {
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = @SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class); ArgumentCaptor.forClass(Map.class);
verify(messageSender).sendMessages(any(), envelopeCaptor.capture()); verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
assertEquals(1, envelopeCaptor.getValue().size()); assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
@ -372,82 +373,6 @@ public class ChangeNumberManagerTest {
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
} }
@Test
void changeNumberMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test
void updatePniKeysMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
final List<Device> devices = new ArrayList<>();
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
}
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
}
@Test @Test
void changeNumberMissingData() { void changeNumberMissingData() {
final Account account = mock(Account.class); final Account account = mock(Account.class);

View File

@ -1,273 +0,0 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ExtendWith(DropwizardExtensionsSupport.class)
class DestinationDeviceValidatorTest {
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
final Account account = mock(Account.class);
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
final Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
});
return account;
}
static Stream<Arguments> validateRegistrationIdsSource() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 1492),
Set.of(id1)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
Map.of(id1, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
Map.of(id1, 0, id2, 42),
Set.of(id2)),
arguments(
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
Map.of(id1, 41, id2, 257),
Set.of(id1, id2))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Map<Byte, Integer> registrationIdsByDeviceId,
Set<Byte> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(
account,
registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey,
Map.Entry::getValue,
false))
.getStaleDevices())
.hasSameElementsAs(expectedStaleDeviceIds);
} else {
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
Map.Entry::getKey, Map.Entry::getValue, false);
}
}
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
});
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceList() {
final byte id1 = 1;
final byte id2 = 2;
final byte id3 = 3;
final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true));
return Stream.of(
// Device IDs provided for all enabled devices
arguments(
account,
Set.of(id1, id3),
Set.of(id2),
null,
Collections.emptySet()),
// Device ID provided for disabled device
arguments(
account,
Set.of(id1, id2, id3),
null,
null,
Collections.emptySet()),
// Device ID omitted for enabled device
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
null,
Collections.emptySet()),
// Device ID included for disabled device, omitted for enabled device
arguments(
account,
Set.of(id1, id2),
Set.of(id3),
null,
Collections.emptySet()),
// Device ID omitted for enabled device, included for device in excluded list
arguments(
account,
Set.of(id1),
Set.of(id2, id3),
Set.of(id1),
Set.of(id1)
),
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
arguments(
account,
Set.of(id2),
Set.of(id3),
null,
Set.of(id1)
),
// Device ID included for enabled device, omitted for excluded device
arguments(
account,
Set.of(id3),
Set.of(id2),
null,
Set.of(id1)
)
);
}
@ParameterizedTest
@MethodSource
void validateCompleteDeviceList(
Account account,
Set<Byte> deviceIds,
Collection<Byte> expectedMissingDeviceIds,
Collection<Byte> expectedExtraDeviceIds,
Set<Byte> excludedDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
}
}
@Test
void testDuplicateDeviceIds() {
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17));
try {
DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false);
Assertions.fail("duplicate devices should throw StaleDevicesException");
} catch (StaleDevicesException e) {
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID));
}
}
@Test
void testValidatePniRegistrationIds() {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
final int aciRegistrationId = 17;
final int pniRegistrationId = 89;
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class,
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
false));
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
false));
assertDoesNotThrow(
() -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true));
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false));
}
}