Internalize destination device list/registration ID checks in `MessageSender`
This commit is contained in:
parent
1d0e2d29a7
commit
c6689ca07a
|
@ -43,12 +43,12 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager
|
|||
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.PhoneNumberDiscoverabilityRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.PhoneNumberIdentityKeyDistributionRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
|
||||
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
|
||||
|
@ -93,8 +93,8 @@ public class AccountControllerV2 {
|
|||
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
|
||||
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
|
||||
@ApiResponse(responseCode = "403", description = "Verification failed for the provided Registration Recovery Password")
|
||||
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
|
||||
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevices.class)))
|
||||
@ApiResponse(responseCode = "409", description = "Mismatched number of devices or device ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
|
||||
@ApiResponse(responseCode = "410", description = "Mismatched registration ids in 'devices to notify' list", content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
|
||||
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
|
||||
@ApiResponse(responseCode = "422", description = "The request did not pass validation")
|
||||
@ApiResponse(responseCode = "423", content = @Content(schema = @Schema(implementation = RegistrationLockFailure.class)))
|
||||
|
@ -150,16 +150,18 @@ public class AccountControllerV2 {
|
|||
|
||||
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
|
||||
} catch (MismatchedDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevices(e.getMissingDevices(),
|
||||
e.getExtraDevices()))
|
||||
.build());
|
||||
} catch (StaleDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevices(e.getStaleDevices()))
|
||||
.build());
|
||||
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
|
||||
.build());
|
||||
} else {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
|
||||
e.getMismatchedDevices().extraDeviceIds()))
|
||||
.build());
|
||||
}
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new BadRequestException(e);
|
||||
} catch (MessageTooLargeException e) {
|
||||
|
@ -178,9 +180,9 @@ public class AccountControllerV2 {
|
|||
@ApiResponse(responseCode = "403", description = "This endpoint can only be invoked from the account's primary device.")
|
||||
@ApiResponse(responseCode = "422", description = "The request body failed validation.")
|
||||
@ApiResponse(responseCode = "409", description = "The set of devices specified in the request does not match the set of devices active on the account.",
|
||||
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
|
||||
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
|
||||
@ApiResponse(responseCode = "410", description = "The registration IDs provided for some devices do not match those stored on the server.",
|
||||
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
|
||||
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
|
||||
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
|
||||
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
|
||||
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
|
||||
|
@ -207,16 +209,18 @@ public class AccountControllerV2 {
|
|||
|
||||
return AccountIdentityResponseBuilder.fromAccount(updatedAccount);
|
||||
} catch (MismatchedDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevices(e.getMissingDevices(),
|
||||
e.getExtraDevices()))
|
||||
.build());
|
||||
} catch (StaleDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevices(e.getStaleDevices()))
|
||||
.build());
|
||||
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
|
||||
.build());
|
||||
} else {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
|
||||
e.getMismatchedDevices().extraDeviceIds()))
|
||||
.build());
|
||||
}
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new BadRequestException(e);
|
||||
} catch (MessageTooLargeException e) {
|
||||
|
|
|
@ -44,14 +44,12 @@ import jakarta.ws.rs.core.Response;
|
|||
import jakarta.ws.rs.core.Response.Status;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CancellationException;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
@ -62,7 +60,6 @@ import javax.annotation.Nullable;
|
|||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.ServiceId;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
import org.signal.libsignal.zkgroup.ServerSecretParams;
|
||||
import org.signal.libsignal.zkgroup.VerificationFailedException;
|
||||
import org.signal.libsignal.zkgroup.groupsend.GroupSendDerivedKeyPair;
|
||||
|
@ -81,13 +78,13 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
|||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.SpamReport;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
|
@ -113,7 +110,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
|
|||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
|
||||
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
||||
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
@ -245,10 +241,10 @@ public class MessageController {
|
|||
description="The message is not a story and some the recipient service ID does not correspond to a registered Signal user")
|
||||
@ApiResponse(
|
||||
responseCode = "409", description = "Incorrect set of devices supplied for recipient",
|
||||
content = @Content(schema = @Schema(implementation = MismatchedDevices.class)))
|
||||
content = @Content(schema = @Schema(implementation = MismatchedDevicesResponse.class)))
|
||||
@ApiResponse(
|
||||
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
|
||||
content = @Content(schema = @Schema(implementation = StaleDevices.class)))
|
||||
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
|
||||
@ApiResponse(
|
||||
responseCode="428",
|
||||
description="The sender should complete a challenge before proceeding")
|
||||
|
@ -381,14 +377,6 @@ public class MessageController {
|
|||
rateLimiters.getStoriesLimiter().validate(destination.getUuid());
|
||||
}
|
||||
|
||||
final Set<Byte> excludedDeviceIds;
|
||||
|
||||
if (isSyncMessage) {
|
||||
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
|
||||
} else {
|
||||
excludedDeviceIds = Collections.emptySet();
|
||||
}
|
||||
|
||||
final Map<Byte, Envelope> messagesByDeviceId = messages.messages().stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> {
|
||||
try {
|
||||
|
@ -407,15 +395,8 @@ public class MessageController {
|
|||
}
|
||||
}));
|
||||
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(destination,
|
||||
messagesByDeviceId.keySet(),
|
||||
excludedDeviceIds);
|
||||
|
||||
DestinationDeviceValidator.validateRegistrationIds(destination,
|
||||
messages.messages(),
|
||||
IncomingMessage::destinationDeviceId,
|
||||
IncomingMessage::destinationRegistrationId,
|
||||
destination.getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = messages.messages().stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
|
||||
|
||||
final String authType;
|
||||
if (SENDER_TYPE_IDENTIFIED.equals(senderType)) {
|
||||
|
@ -428,7 +409,7 @@ public class MessageController {
|
|||
authType = AUTH_TYPE_ACCESS_KEY;
|
||||
}
|
||||
|
||||
messageSender.sendMessages(destination, messagesByDeviceId);
|
||||
messageSender.sendMessages(destination, destinationIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
|
||||
|
||||
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, List.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||
Tag.of(ENDPOINT_TYPE_TAG_NAME, ENDPOINT_TYPE_SINGLE),
|
||||
|
@ -440,16 +421,18 @@ public class MessageController {
|
|||
|
||||
return Response.ok(new SendMessageResponse(needsSync)).build();
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevices(e.getMissingDevices(),
|
||||
e.getExtraDevices()))
|
||||
.build());
|
||||
} catch (final StaleDevicesException e) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevices(e.getStaleDevices()))
|
||||
.build());
|
||||
if (!e.getMismatchedDevices().staleDeviceIds().isEmpty()) {
|
||||
throw new WebApplicationException(Response.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(new StaleDevicesResponse(e.getMismatchedDevices().staleDeviceIds()))
|
||||
.build());
|
||||
} else {
|
||||
throw new WebApplicationException(Response.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(new MismatchedDevicesResponse(e.getMismatchedDevices().missingDeviceIds(),
|
||||
e.getMismatchedDevices().extraDeviceIds()))
|
||||
.build());
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
sample.stop(Timer.builder(SEND_MESSAGE_LATENCY_TIMER_NAME)
|
||||
|
@ -622,57 +605,6 @@ public class MessageController {
|
|||
}
|
||||
}
|
||||
|
||||
final Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
|
||||
final Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
|
||||
|
||||
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
|
||||
if (!resolvedRecipients.containsKey(recipient)) {
|
||||
// When sending stories, we might not be able to resolve all recipients to existing accounts. That's okay! We
|
||||
// can just skip them.
|
||||
return;
|
||||
}
|
||||
|
||||
final Account account = resolvedRecipients.get(recipient);
|
||||
|
||||
try {
|
||||
final Map<Byte, Short> deviceIdsToRegistrationIds = recipient.getDevicesAndRegistrationIds()
|
||||
.collect(Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second));
|
||||
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIdsToRegistrationIds.keySet(),
|
||||
Collections.emptySet());
|
||||
|
||||
DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
deviceIdsToRegistrationIds.entrySet(),
|
||||
Map.Entry<Byte, Short>::getKey,
|
||||
e -> Integer.valueOf(e.getValue()),
|
||||
serviceId instanceof ServiceId.Pni);
|
||||
} catch (final MismatchedDevicesException e) {
|
||||
accountMismatchedDevices.add(
|
||||
new AccountMismatchedDevices(
|
||||
ServiceIdentifier.fromLibsignal(serviceId),
|
||||
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
|
||||
} catch (final StaleDevicesException e) {
|
||||
accountStaleDevices.add(
|
||||
new AccountStaleDevices(ServiceIdentifier.fromLibsignal(serviceId), new StaleDevices(e.getStaleDevices())));
|
||||
}
|
||||
});
|
||||
|
||||
if (!accountMismatchedDevices.isEmpty()) {
|
||||
return Response
|
||||
.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(accountMismatchedDevices)
|
||||
.build();
|
||||
}
|
||||
if (!accountStaleDevices.isEmpty()) {
|
||||
return Response
|
||||
.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(accountStaleDevices)
|
||||
.build();
|
||||
}
|
||||
|
||||
final String authType;
|
||||
if (isStory) {
|
||||
authType = AUTH_TYPE_STORY;
|
||||
|
@ -731,6 +663,38 @@ public class MessageController {
|
|||
} catch (ExecutionException e) {
|
||||
logger.error("partial failure while delivering multi-recipient messages", e.getCause());
|
||||
throw new InternalServerErrorException("failure during delivery");
|
||||
} catch (MultiRecipientMismatchedDevicesException e) {
|
||||
final List<AccountMismatchedDevices> accountMismatchedDevices =
|
||||
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
|
||||
.filter(entry -> !entry.getValue().missingDeviceIds().isEmpty() || !entry.getValue().extraDeviceIds().isEmpty())
|
||||
.map(entry -> new AccountMismatchedDevices(entry.getKey(),
|
||||
new MismatchedDevicesResponse(entry.getValue().missingDeviceIds(), entry.getValue().extraDeviceIds())))
|
||||
.toList();
|
||||
|
||||
if (!accountMismatchedDevices.isEmpty()) {
|
||||
return Response
|
||||
.status(409)
|
||||
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||
.entity(accountMismatchedDevices)
|
||||
.build();
|
||||
}
|
||||
|
||||
final List<AccountStaleDevices> accountStaleDevices =
|
||||
e.getMismatchedDevicesByServiceIdentifier().entrySet().stream()
|
||||
.filter(entry -> !entry.getValue().staleDeviceIds().isEmpty())
|
||||
.map(entry -> new AccountStaleDevices(entry.getKey(),
|
||||
new StaleDevicesResponse(entry.getValue().staleDeviceIds())))
|
||||
.toList();
|
||||
|
||||
if (!accountStaleDevices.isEmpty()) {
|
||||
return Response
|
||||
.status(410)
|
||||
.type(MediaType.APPLICATION_JSON)
|
||||
.entity(accountStaleDevices)
|
||||
.build();
|
||||
}
|
||||
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public record MismatchedDevices(Set<Byte> missingDeviceIds, Set<Byte> extraDeviceIds, Set<Byte> staleDeviceIds) {
|
||||
}
|
|
@ -5,23 +5,15 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class MismatchedDevicesException extends Exception {
|
||||
|
||||
private final List<Byte> missingDevices;
|
||||
private final List<Byte> extraDevices;
|
||||
private final MismatchedDevices mismatchedDevices;
|
||||
|
||||
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) {
|
||||
this.missingDevices = missingDevices;
|
||||
this.extraDevices = extraDevices;
|
||||
public MismatchedDevicesException(final MismatchedDevices mismatchedDevices) {
|
||||
this.mismatchedDevices = mismatchedDevices;
|
||||
}
|
||||
|
||||
public List<Byte> getMissingDevices() {
|
||||
return missingDevices;
|
||||
}
|
||||
|
||||
public List<Byte> getExtraDevices() {
|
||||
return extraDevices;
|
||||
public MismatchedDevices getMismatchedDevices() {
|
||||
return mismatchedDevices;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
/*
|
||||
* Copyright 2025 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import java.util.Map;
|
||||
|
||||
public class MultiRecipientMismatchedDevicesException extends Exception {
|
||||
|
||||
private final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier;
|
||||
|
||||
public MultiRecipientMismatchedDevicesException(
|
||||
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier) {
|
||||
|
||||
this.mismatchedDevicesByServiceIdentifier = mismatchedDevicesByServiceIdentifier;
|
||||
}
|
||||
|
||||
public Map<ServiceIdentifier, MismatchedDevices> getMismatchedDevicesByServiceIdentifier() {
|
||||
return mismatchedDevicesByServiceIdentifier;
|
||||
}
|
||||
}
|
|
@ -1,22 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.controllers;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public class StaleDevicesException extends Exception {
|
||||
|
||||
private final List<Byte> staleDevices;
|
||||
|
||||
public StaleDevicesException(List<Byte> staleDevices) {
|
||||
this.staleDevices = staleDevices;
|
||||
}
|
||||
|
||||
public List<Byte> getStaleDevices() {
|
||||
return staleDevices;
|
||||
}
|
||||
}
|
|
@ -14,5 +14,5 @@ public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierA
|
|||
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
|
||||
ServiceIdentifier uuid,
|
||||
|
||||
MismatchedDevices devices) {
|
||||
MismatchedDevicesResponse devices) {
|
||||
}
|
||||
|
|
|
@ -14,5 +14,5 @@ public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapte
|
|||
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
|
||||
ServiceIdentifier uuid,
|
||||
|
||||
StaleDevices devices) {
|
||||
StaleDevicesResponse devices) {
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record MismatchedDevices(@JsonProperty
|
||||
@Schema(description = "Devices present on the account but absent in the request")
|
||||
List<Byte> missingDevices,
|
||||
|
||||
@JsonProperty
|
||||
@Schema(description = "Devices absent on the request but present in the account")
|
||||
List<Byte> extraDevices) {
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.entities;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public record MismatchedDevicesResponse(@JsonProperty
|
||||
@Schema(description = "Devices present on the account but absent in the request")
|
||||
Set<Byte> missingDevices,
|
||||
|
||||
@JsonProperty
|
||||
@Schema(description = "Devices absent on the request but present in the account")
|
||||
Set<Byte> extraDevices) {
|
||||
}
|
|
@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.entities;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public record StaleDevices(@JsonProperty
|
||||
@Schema(description = "Devices that are no longer active")
|
||||
List<Byte> staleDevices) {
|
||||
public record StaleDevicesResponse(@JsonProperty
|
||||
@Schema(description = "Devices that are no longer active")
|
||||
Set<Byte> staleDevices) {
|
||||
}
|
|
@ -11,13 +11,24 @@ import com.google.common.annotations.VisibleForTesting;
|
|||
import io.dropwizard.util.DataSize;
|
||||
import io.micrometer.core.instrument.DistributionSummary;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import io.micrometer.core.instrument.Tag;
|
||||
import io.micrometer.core.instrument.Tags;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
|
@ -58,6 +69,9 @@ public class MessageSender {
|
|||
public static final int MAX_MESSAGE_SIZE = (int) DataSize.kibibytes(256).toBytes();
|
||||
private static final long LARGE_MESSAGE_SIZE = DataSize.kibibytes(8).toBytes();
|
||||
|
||||
@VisibleForTesting
|
||||
static final byte NO_EXCLUDED_DEVICE_ID = -1;
|
||||
|
||||
public MessageSender(final MessagesManager messagesManager, final PushNotificationManager pushNotificationManager) {
|
||||
this.messagesManager = messagesManager;
|
||||
this.pushNotificationManager = pushNotificationManager;
|
||||
|
@ -68,23 +82,51 @@ public class MessageSender {
|
|||
* notification token and does not have an active connection to a Signal server, then this method will also send a
|
||||
* push notification to that device to announce the availability of new messages.
|
||||
*
|
||||
* @param account the account to which to send messages
|
||||
* @param destination the account to which to send messages
|
||||
* @param destinationIdentifier the service identifier to which the messages are addressed
|
||||
* @param messagesByDeviceId a map of device IDs to message payloads
|
||||
* @param registrationIdsByDeviceId a map of device IDs to device registration IDs
|
||||
*/
|
||||
public void sendMessages(final Account account, final Map<Byte, Envelope> messagesByDeviceId) {
|
||||
messagesManager.insert(account.getIdentifier(IdentityType.ACI), messagesByDeviceId)
|
||||
public void sendMessages(final Account destination,
|
||||
final ServiceIdentifier destinationIdentifier,
|
||||
final Map<Byte, Envelope> messagesByDeviceId,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId) throws MismatchedDevicesException {
|
||||
|
||||
if (messagesByDeviceId.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!destination.isIdentifiedBy(destinationIdentifier)) {
|
||||
throw new IllegalArgumentException("Destination account not identified by destination service identifier");
|
||||
}
|
||||
|
||||
final Envelope firstMessage = messagesByDeviceId.values().iterator().next();
|
||||
|
||||
final boolean isSyncMessage = StringUtils.isNotBlank(firstMessage.getSourceServiceId()) &&
|
||||
destination.isIdentifiedBy(ServiceIdentifier.valueOf(firstMessage.getSourceServiceId()));
|
||||
|
||||
final Optional<MismatchedDevices> maybeMismatchedDevices = getMismatchedDevices(destination,
|
||||
destinationIdentifier,
|
||||
registrationIdsByDeviceId,
|
||||
isSyncMessage ? (byte) firstMessage.getSourceDevice() : NO_EXCLUDED_DEVICE_ID);
|
||||
|
||||
if (maybeMismatchedDevices.isPresent()) {
|
||||
throw new MismatchedDevicesException(maybeMismatchedDevices.get());
|
||||
}
|
||||
|
||||
messagesManager.insert(destination.getIdentifier(IdentityType.ACI), messagesByDeviceId)
|
||||
.forEach((deviceId, destinationPresent) -> {
|
||||
final Envelope message = messagesByDeviceId.get(deviceId);
|
||||
|
||||
if (!destinationPresent && !message.getEphemeral()) {
|
||||
try {
|
||||
pushNotificationManager.sendNewMessageNotification(account, deviceId, message.getUrgent());
|
||||
pushNotificationManager.sendNewMessageNotification(destination, deviceId, message.getUrgent());
|
||||
} catch (final NotPushRegisteredException ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
Metrics.counter(SEND_COUNTER_NAME,
|
||||
CHANNEL_TAG_NAME, account.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
|
||||
CHANNEL_TAG_NAME, destination.getDevice(deviceId).map(MessageSender::getDeliveryChannelName).orElse("unknown"),
|
||||
EPHEMERAL_TAG_NAME, String.valueOf(message.getEphemeral()),
|
||||
CLIENT_ONLINE_TAG_NAME, String.valueOf(destinationPresent),
|
||||
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
|
||||
|
@ -98,6 +140,10 @@ public class MessageSender {
|
|||
* Sends messages to a group of recipients. If a destination device has a valid push notification token and does not
|
||||
* have an active connection to a Signal server, then this method will also send a push notification to that device to
|
||||
* announce the availability of new messages.
|
||||
* <p>
|
||||
* This method sends messages to all <em>resolved</em> recipients. In some cases, a caller may not be able to resolve
|
||||
* all recipients to active accounts, but may still choose to send the message. Callers are responsible for rejecting
|
||||
* the message if they require full resolution of all recipients, but some recipients could not be resolved.
|
||||
*
|
||||
* @param multiRecipientMessage the multi-recipient message to send to the given recipients
|
||||
* @param resolvedRecipients a map of recipients to resolved Signal accounts
|
||||
|
@ -114,7 +160,31 @@ public class MessageSender {
|
|||
final long clientTimestamp,
|
||||
final boolean isStory,
|
||||
final boolean isEphemeral,
|
||||
final boolean isUrgent) {
|
||||
final boolean isUrgent) throws MultiRecipientMismatchedDevicesException {
|
||||
|
||||
final Map<ServiceIdentifier, MismatchedDevices> mismatchedDevicesByServiceIdentifier = new HashMap<>();
|
||||
|
||||
multiRecipientMessage.getRecipients().forEach((serviceId, recipient) -> {
|
||||
if (!resolvedRecipients.containsKey(recipient)) {
|
||||
// Callers are responsible for rejecting messages if they're missing recipients in a problematic way. If we run
|
||||
// into an unresolved recipient here, just skip it.
|
||||
return;
|
||||
}
|
||||
|
||||
final Account account = resolvedRecipients.get(recipient);
|
||||
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromLibsignal(serviceId);
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = recipient.getDevicesAndRegistrationIds()
|
||||
.collect(Collectors.toMap(Pair::first, pair -> (int) pair.second()));
|
||||
|
||||
getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, NO_EXCLUDED_DEVICE_ID)
|
||||
.ifPresent(mismatchedDevices ->
|
||||
mismatchedDevicesByServiceIdentifier.put(serviceIdentifier, mismatchedDevices));
|
||||
});
|
||||
|
||||
if (!mismatchedDevicesByServiceIdentifier.isEmpty()) {
|
||||
throw new MultiRecipientMismatchedDevicesException(mismatchedDevicesByServiceIdentifier);
|
||||
}
|
||||
|
||||
return messagesManager.insertMultiRecipientMessage(multiRecipientMessage, resolvedRecipients, clientTimestamp,
|
||||
isStory, isEphemeral, isUrgent)
|
||||
|
@ -189,4 +259,47 @@ public class MessageSender {
|
|||
.increment();
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static Optional<MismatchedDevices> getMismatchedDevices(final Account account,
|
||||
final ServiceIdentifier serviceIdentifier,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
final byte excludedDeviceId) {
|
||||
|
||||
final Set<Byte> accountDeviceIds = account.getDevices().stream()
|
||||
.map(Device::getId)
|
||||
.filter(deviceId -> deviceId != excludedDeviceId)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
|
||||
missingDeviceIds.removeAll(registrationIdsByDeviceId.keySet());
|
||||
|
||||
final Set<Byte> extraDeviceIds = new HashSet<>(registrationIdsByDeviceId.keySet());
|
||||
extraDeviceIds.removeAll(accountDeviceIds);
|
||||
|
||||
final Set<Byte> staleDeviceIds = registrationIdsByDeviceId.entrySet().stream()
|
||||
// Filter out device IDs that aren't associated with the given account
|
||||
.filter(entry -> !extraDeviceIds.contains(entry.getKey()))
|
||||
.filter(entry -> {
|
||||
final byte deviceId = entry.getKey();
|
||||
final int registrationId = entry.getValue();
|
||||
|
||||
// We know the device must be present because we've already filtered out device IDs that aren't associated
|
||||
// with the given account
|
||||
final Device device = account.getDevice(deviceId).orElseThrow();
|
||||
|
||||
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
};
|
||||
|
||||
return registrationId != expectedRegistrationId;
|
||||
})
|
||||
.map(Map.Entry::getKey)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
return (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty() || !staleDeviceIds.isEmpty())
|
||||
? Optional.of(new MismatchedDevices(missingDeviceIds, extraDeviceIds, staleDeviceIds))
|
||||
: Optional.empty();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.push;
|
|||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.stream.Collectors;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -54,9 +55,20 @@ public class ReceiptSender {
|
|||
.setUrgent(false)
|
||||
.build();
|
||||
|
||||
final Map<Byte, Envelope> messagesByDeviceId = destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, ignored -> message));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId);
|
||||
}));
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destinationAccount, destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, ignored -> message)));
|
||||
messageSender.sendMessages(destinationAccount,
|
||||
destinationIdentifier,
|
||||
messagesByDeviceId,
|
||||
registrationIdsByDeviceId);
|
||||
} catch (final Exception e) {
|
||||
logger.warn("Could not send delivery receipt", e);
|
||||
}
|
||||
|
|
|
@ -37,10 +37,12 @@ import java.util.Arrays;
|
|||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Queue;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -55,6 +57,7 @@ import java.util.function.BiFunction;
|
|||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.Mac;
|
||||
|
@ -67,6 +70,7 @@ import org.slf4j.LoggerFactory;
|
|||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
|
||||
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
|
||||
|
@ -83,7 +87,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
|||
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
|
||||
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
|
||||
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
|
||||
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
@ -780,24 +783,33 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
|
|||
}
|
||||
|
||||
// Check that all including primary ID are in signed pre-keys
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(
|
||||
account,
|
||||
pniSignedPreKeys.keySet(),
|
||||
Collections.emptySet());
|
||||
validateCompleteDeviceList(account, pniSignedPreKeys.keySet());
|
||||
|
||||
// Check that all including primary ID are in Pq pre-keys
|
||||
if (pniPqLastResortPreKeys != null) {
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(
|
||||
account,
|
||||
pniPqLastResortPreKeys.keySet(),
|
||||
Collections.emptySet());
|
||||
validateCompleteDeviceList(account, pniPqLastResortPreKeys.keySet());
|
||||
}
|
||||
|
||||
// Check that all devices are accounted for in the map of new PNI registration IDs
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(
|
||||
account,
|
||||
pniRegistrationIds.keySet(),
|
||||
Collections.emptySet());
|
||||
validateCompleteDeviceList(account, pniRegistrationIds.keySet());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds) throws MismatchedDevicesException {
|
||||
final Set<Byte> accountDeviceIds = account.getDevices().stream()
|
||||
.map(Device::getId)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
|
||||
missingDeviceIds.removeAll(deviceIds);
|
||||
|
||||
final Set<Byte> extraDeviceIds = new HashSet<>(deviceIds);
|
||||
extraDeviceIds.removeAll(accountDeviceIds);
|
||||
|
||||
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
|
||||
throw new MismatchedDevicesException(
|
||||
new MismatchedDevices(missingDeviceIds, extraDeviceIds, Collections.emptySet()));
|
||||
}
|
||||
}
|
||||
|
||||
public record UsernameReservation(Account account, byte[] reservedUsernameHash){}
|
||||
|
|
|
@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.storage;
|
|||
import com.google.protobuf.ByteString;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.Nullable;
|
||||
import org.apache.commons.lang3.ObjectUtils;
|
||||
|
@ -15,15 +14,15 @@ import org.signal.libsignal.protocol.IdentityKey;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
|
||||
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
||||
|
||||
public class ChangeNumberManager {
|
||||
|
||||
|
@ -45,12 +44,10 @@ public class ChangeNumberManager {
|
|||
@Nullable final List<IncomingMessage> deviceMessages,
|
||||
@Nullable final Map<Byte, Integer> pniRegistrationIds,
|
||||
@Nullable final String senderUserAgent)
|
||||
throws InterruptedException, MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
|
||||
throws InterruptedException, MismatchedDevicesException, MessageTooLargeException {
|
||||
|
||||
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
|
||||
// AccountsManager validates the device set on deviceSignedPreKeys and pniRegistrationIds
|
||||
validateDeviceMessages(account, deviceMessages);
|
||||
} else if (!ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
|
||||
if (!(ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds) ||
|
||||
ObjectUtils.allNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds))) {
|
||||
throw new IllegalArgumentException("PNI identity key, signed pre-keys, device messages, and registration IDs must be all null or all non-null");
|
||||
}
|
||||
|
||||
|
@ -84,9 +81,7 @@ public class ChangeNumberManager {
|
|||
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
|
||||
final List<IncomingMessage> deviceMessages,
|
||||
final Map<Byte, Integer> pniRegistrationIds,
|
||||
final String senderUserAgent) throws MismatchedDevicesException, StaleDevicesException, MessageTooLargeException {
|
||||
|
||||
validateDeviceMessages(account, deviceMessages);
|
||||
final String senderUserAgent) throws MismatchedDevicesException, MessageTooLargeException {
|
||||
|
||||
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
|
||||
// write anyway. Linked devices can handle some wasted extra key rotations.
|
||||
|
@ -97,26 +92,9 @@ public class ChangeNumberManager {
|
|||
return updatedAccount;
|
||||
}
|
||||
|
||||
private void validateDeviceMessages(final Account account,
|
||||
final List<IncomingMessage> deviceMessages) throws MismatchedDevicesException, StaleDevicesException {
|
||||
// Check that all except primary ID are in device messages
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(
|
||||
account,
|
||||
deviceMessages.stream().map(IncomingMessage::destinationDeviceId).collect(Collectors.toSet()),
|
||||
Set.of(Device.PRIMARY_ID));
|
||||
|
||||
// check that all sync messages are to the current registration ID for the matching device
|
||||
DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
deviceMessages,
|
||||
IncomingMessage::destinationDeviceId,
|
||||
IncomingMessage::destinationRegistrationId,
|
||||
false);
|
||||
}
|
||||
|
||||
private void sendDeviceMessages(final Account account,
|
||||
final List<IncomingMessage> deviceMessages,
|
||||
final String senderUserAgent) throws MessageTooLargeException {
|
||||
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
|
||||
|
||||
for (final IncomingMessage message : deviceMessages) {
|
||||
MessageSender.validateContentLength(message.content().length,
|
||||
|
@ -128,20 +106,26 @@ public class ChangeNumberManager {
|
|||
|
||||
try {
|
||||
final long serverTimestamp = System.currentTimeMillis();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
|
||||
|
||||
messageSender.sendMessages(account, deviceMessages.stream()
|
||||
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
|
||||
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, message -> Envelope.newBuilder()
|
||||
.setType(Envelope.Type.forNumber(message.type()))
|
||||
.setClientTimestamp(serverTimestamp)
|
||||
.setServerTimestamp(serverTimestamp)
|
||||
.setDestinationServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
|
||||
.setDestinationServiceId(serviceIdentifier.toServiceIdentifierString())
|
||||
.setContent(ByteString.copyFrom(message.content()))
|
||||
.setSourceServiceId(new AciServiceIdentifier(account.getUuid()).toServiceIdentifierString())
|
||||
.setSourceServiceId(serviceIdentifier.toServiceIdentifierString())
|
||||
.setSourceDevice(Device.PRIMARY_ID)
|
||||
.setUpdatedPni(account.getPhoneNumberIdentifier().toString())
|
||||
.setUrgent(true)
|
||||
.setEphemeral(false)
|
||||
.build())));
|
||||
.build()));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
|
||||
|
||||
messageSender.sendMessages(account, serviceIdentifier, messagesByDeviceId, registrationIdsByDeviceId);
|
||||
} catch (final RuntimeException e) {
|
||||
logger.warn("Changed number but could not send all device messages on {}", account.getUuid(), e);
|
||||
throw e;
|
||||
|
|
|
@ -1,108 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2022 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
public class DestinationDeviceValidator {
|
||||
|
||||
/**
|
||||
* @see #validateRegistrationIds(Account, Stream, boolean)
|
||||
*/
|
||||
public static <T> void validateRegistrationIds(final Account account,
|
||||
final Collection<T> messages,
|
||||
Function<T, Byte> getDeviceId,
|
||||
Function<T, Integer> getRegistrationId,
|
||||
boolean usePhoneNumberIdentity) throws StaleDevicesException {
|
||||
|
||||
validateRegistrationIds(account,
|
||||
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
|
||||
usePhoneNumberIdentity);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the given device ID/registration ID pairs exactly match the corresponding device ID/registration ID
|
||||
* pairs in the given destination account. This method does <em>not</em> validate that all devices associated with the
|
||||
* destination account are present in the given device ID/registration ID pairs.
|
||||
*
|
||||
* @param account the destination account against which to check the given device
|
||||
* ID/registration ID pairs
|
||||
* @param deviceIdAndRegistrationIdStream a stream of device ID and registration ID pairs
|
||||
* @param usePhoneNumberIdentity if {@code true}, compare provided registration IDs against device
|
||||
* registration IDs associated with the account's PNI (if available); compare
|
||||
* against the ACI-associated registration ID otherwise
|
||||
* @throws StaleDevicesException if the device ID/registration ID pairs contained an entry for which the destination
|
||||
* account does not have a corresponding device or if the registration IDs do not match
|
||||
*/
|
||||
public static void validateRegistrationIds(final Account account,
|
||||
final Stream<Pair<Byte, Integer>> deviceIdAndRegistrationIdStream,
|
||||
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
|
||||
|
||||
final List<Byte> staleDevices = deviceIdAndRegistrationIdStream
|
||||
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
|
||||
.filter(deviceIdAndRegistrationId -> {
|
||||
final byte deviceId = deviceIdAndRegistrationId.first();
|
||||
final int registrationId = deviceIdAndRegistrationId.second();
|
||||
boolean registrationIdMatches = account.getDevice(deviceId)
|
||||
.map(device -> registrationId == (usePhoneNumberIdentity
|
||||
? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId())
|
||||
: device.getRegistrationId()))
|
||||
.orElse(false);
|
||||
return !registrationIdMatches;
|
||||
})
|
||||
.map(Pair::first)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (!staleDevices.isEmpty()) {
|
||||
throw new StaleDevicesException(staleDevices);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the given set of device IDs from a set of messages matches the set of device IDs associated with the
|
||||
* given destination account in preparation for sending those messages to the destination account. In general, the set
|
||||
* of device IDs must exactly match the set of active devices associated with the destination account. When sending a
|
||||
* "sync," message, though, the authenticated account is sending messages from one of their devices to all other
|
||||
* devices; in that case, callers must pass the ID of the sending device in the set of {@code excludedDeviceIds}.
|
||||
*
|
||||
* @param account the destination account against which to check the given set of device IDs
|
||||
* @param messageDeviceIds the set of device IDs to check against the destination account
|
||||
* @param excludedDeviceIds a set of device IDs that may be associated with the destination account, but must not be
|
||||
* present in the given set of device IDs (i.e. the device that is sending a sync message)
|
||||
* @throws MismatchedDevicesException if the given set of device IDs contains entries not currently associated with
|
||||
* the destination account or is missing entries associated with the destination
|
||||
* account
|
||||
*/
|
||||
public static void validateCompleteDeviceList(final Account account,
|
||||
final Set<Byte> messageDeviceIds,
|
||||
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
|
||||
|
||||
final Set<Byte> accountDeviceIds = account.getDevices().stream()
|
||||
.map(Device::getId)
|
||||
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
|
||||
missingDeviceIds.removeAll(messageDeviceIds);
|
||||
|
||||
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
|
||||
extraDeviceIds.removeAll(accountDeviceIds);
|
||||
|
||||
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {
|
||||
throw new MismatchedDevicesException(new ArrayList<>(missingDeviceIds), new ArrayList<>(extraDeviceIds));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
|
|||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.anyBoolean;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.reset;
|
||||
|
@ -76,12 +77,12 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
|||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
|
||||
import org.whispersystems.textsecuregcm.entities.SpamReport;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||
import org.whispersystems.textsecuregcm.entities.StaleDevicesResponse;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
|
@ -120,6 +121,7 @@ import org.whispersystems.websocket.WebsocketHeaders;
|
|||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class MessageControllerTest {
|
||||
|
@ -195,7 +197,7 @@ class MessageControllerTest {
|
|||
.build();
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
void setup() throws MultiRecipientMismatchedDevicesException {
|
||||
reset(pushNotificationScheduler);
|
||||
|
||||
when(messageSender.sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
|
@ -287,7 +289,7 @@ class MessageControllerTest {
|
|||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
|
@ -332,7 +334,7 @@ class MessageControllerTest {
|
|||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
|
@ -357,7 +359,7 @@ class MessageControllerTest {
|
|||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
|
@ -395,7 +397,7 @@ class MessageControllerTest {
|
|||
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
|
@ -434,7 +436,7 @@ class MessageControllerTest {
|
|||
assertThat("Good Response", response.getStatus(), is(equalTo(expectedResponse)));
|
||||
if (expectedResponse == 200) {
|
||||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> captor = ArgumentCaptor.forClass(Map.class);
|
||||
verify(messageSender).sendMessages(any(), captor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), captor.capture(), any());
|
||||
|
||||
assertEquals(1, captor.getValue().size());
|
||||
final Envelope message = captor.getValue().values().stream().findFirst().orElseThrow();
|
||||
|
@ -530,6 +532,9 @@ class MessageControllerTest {
|
|||
|
||||
@Test
|
||||
void testMultiDeviceMissing() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2, (byte) 3), Collections.emptySet(), Collections.emptySet())))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
|
@ -542,15 +547,16 @@ class MessageControllerTest {
|
|||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(MismatchedDevices.class)),
|
||||
asJson(response.readEntity(MismatchedDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/missing_device_response.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiDeviceExtra() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Set.of((byte) 2), Set.of((byte) 4), Collections.emptySet())))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest()
|
||||
.target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
|
@ -563,10 +569,8 @@ class MessageControllerTest {
|
|||
assertThat("Good Response Code", response.getStatus(), is(equalTo(409)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(MismatchedDevices.class)),
|
||||
asJson(response.readEntity(MismatchedDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/missing_device_response2.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -602,7 +606,7 @@ class MessageControllerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(3, envelopeCaptor.getValue().size());
|
||||
|
||||
|
@ -626,7 +630,7 @@ class MessageControllerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(Account.class), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(3, envelopeCaptor.getValue().size());
|
||||
|
||||
|
@ -648,12 +652,17 @@ class MessageControllerTest {
|
|||
assertThat("Good Response Code", response.getStatus(), is(equalTo(200)));
|
||||
|
||||
verify(messageSender).sendMessages(any(Account.class),
|
||||
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3));
|
||||
any(),
|
||||
argThat(messagesByDeviceId -> messagesByDeviceId.size() == 3),
|
||||
any());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testRegistrationIdMismatch() throws Exception {
|
||||
doThrow(new MismatchedDevicesException(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) 2))))
|
||||
.when(messageSender).sendMessages(any(), any(), any(), any());
|
||||
|
||||
try (final Response response =
|
||||
resources.getJerseyTest().target(String.format("/v1/messages/%s", MULTI_DEVICE_UUID))
|
||||
.request()
|
||||
|
@ -665,10 +674,8 @@ class MessageControllerTest {
|
|||
assertThat("Good Response Code", response.getStatus(), is(equalTo(410)));
|
||||
|
||||
assertThat("Good Response Body",
|
||||
asJson(response.readEntity(StaleDevices.class)),
|
||||
asJson(response.readEntity(StaleDevicesResponse.class)),
|
||||
is(equalTo(jsonFixture("fixtures/mismatched_registration_id.json"))));
|
||||
|
||||
verifyNoMoreInteractions(messageSender);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1078,7 +1085,7 @@ class MessageControllerTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
void testValidateContentLength() {
|
||||
void testValidateContentLength() throws MismatchedDevicesException {
|
||||
final int contentLength = Math.toIntExact(MessageSender.MAX_MESSAGE_SIZE + 1);
|
||||
final byte[] contentBytes = new byte[contentLength];
|
||||
Arrays.fill(contentBytes, (byte) 1);
|
||||
|
@ -1095,7 +1102,7 @@ class MessageControllerTest {
|
|||
|
||||
assertThat("Bad response", response.getStatus(), is(equalTo(413)));
|
||||
|
||||
verify(messageSender, never()).sendMessages(any(), any());
|
||||
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1113,10 +1120,10 @@ class MessageControllerTest {
|
|||
|
||||
if (expectOk) {
|
||||
assertEquals(200, response.getStatus());
|
||||
verify(messageSender).sendMessages(any(), any());
|
||||
verify(messageSender).sendMessages(any(), any(), any(), any());
|
||||
} else {
|
||||
assertEquals(422, response.getStatus());
|
||||
verify(messageSender, never()).sendMessages(any(), any());
|
||||
verify(messageSender, never()).sendMessages(any(), any(), any(), any());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1140,7 +1147,9 @@ class MessageControllerTest {
|
|||
final Optional<String> maybeGroupSendToken,
|
||||
final int expectedStatus,
|
||||
final Set<Account> expectedResolvedAccounts,
|
||||
final Set<ServiceIdentifier> expectedUuids404) {
|
||||
final Set<ServiceIdentifier> expectedUuids404,
|
||||
@Nullable final MultiRecipientMismatchedDevicesException mismatchedDevicesException)
|
||||
throws MultiRecipientMismatchedDevicesException {
|
||||
|
||||
clock.pin(START_OF_DAY);
|
||||
|
||||
|
@ -1151,6 +1160,11 @@ class MessageControllerTest {
|
|||
when(accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
|
||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)))));
|
||||
|
||||
if (mismatchedDevicesException != null) {
|
||||
doThrow(mismatchedDevicesException)
|
||||
.when(messageSender).sendMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean());
|
||||
}
|
||||
|
||||
final boolean ephemeral = true;
|
||||
final boolean urgent = false;
|
||||
|
||||
|
@ -1187,7 +1201,7 @@ class MessageControllerTest {
|
|||
assertThat(Set.copyOf(entity.uuids404()), equalTo(expectedUuids404));
|
||||
}
|
||||
|
||||
if (expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) {
|
||||
if ((expectedStatus == 200 && !expectedResolvedAccounts.isEmpty()) || mismatchedDevicesException != null) {
|
||||
verify(messageSender).sendMultiRecipientMessage(any(),
|
||||
argThat(resolvedRecipients ->
|
||||
new HashSet<>(resolvedRecipients.values()).equals(expectedResolvedAccounts)),
|
||||
|
@ -1267,7 +1281,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message with combined UAKs",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1279,7 +1294,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1291,7 +1307,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Incorrect combined UAK",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1303,7 +1320,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Incorrect group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1317,7 +1335,8 @@ class MessageControllerTest {
|
|||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
// Stories don't require credentials of any kind, but for historical reasons, we don't reject a combined UAK if
|
||||
// provided
|
||||
|
@ -1331,7 +1350,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1343,7 +1363,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Conflicting credentials",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1355,7 +1376,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("No credentials",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1367,7 +1389,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Oversized payload",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1383,7 +1406,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
413,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Negative timestamp",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1395,7 +1419,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Excessive timestamp",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1407,7 +1432,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Empty recipient list",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1421,7 +1447,8 @@ class MessageControllerTest {
|
|||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story with empty recipient list",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1433,7 +1460,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Duplicate recipient",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1447,7 +1475,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
400,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing account",
|
||||
Map.of(),
|
||||
|
@ -1459,7 +1488,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Collections.emptySet(),
|
||||
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci))),
|
||||
Set.of(new AciServiceIdentifier(singleDeviceAccountAci), new AciServiceIdentifier(multiDeviceAccountAci)),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("One missing and one existing account",
|
||||
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
|
||||
|
@ -1471,7 +1501,8 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
200,
|
||||
Set.of(singleDeviceAccount),
|
||||
Set.of(new AciServiceIdentifier(multiDeviceAccountAci))),
|
||||
Set.of(new AciServiceIdentifier(multiDeviceAccountAci)),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing account for story",
|
||||
Map.of(),
|
||||
|
@ -1483,7 +1514,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Collections.emptySet(),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("One missing and one existing account for story",
|
||||
Map.of(new AciServiceIdentifier(singleDeviceAccountAci), singleDeviceAccount),
|
||||
|
@ -1495,7 +1527,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Missing device",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1509,7 +1542,9 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
409,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Set.of((byte) (Device.PRIMARY_ID + 1)), Collections.emptySet(), Collections.emptySet())))),
|
||||
|
||||
Arguments.argumentSet("Extra device",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1525,7 +1560,9 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
409,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 2)), Collections.emptySet())))),
|
||||
|
||||
Arguments.argumentSet("Stale registration ID",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1540,7 +1577,9 @@ class MessageControllerTest {
|
|||
Optional.of(groupSendEndorsement),
|
||||
410,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
new MultiRecipientMismatchedDevicesException(Map.of(new AciServiceIdentifier(multiDeviceAccountAci),
|
||||
new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of((byte) (Device.PRIMARY_ID + 1)))))),
|
||||
|
||||
Arguments.argumentSet("Rate-limited story",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1552,7 +1591,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
429,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Story to PNI recipients",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1567,7 +1607,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message to PNI recipients with UAK",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1582,7 +1623,8 @@ class MessageControllerTest {
|
|||
Optional.empty(),
|
||||
401,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of()),
|
||||
Set.of(),
|
||||
null),
|
||||
|
||||
Arguments.argumentSet("Multi-recipient message to PNI recipients with group send endorsement",
|
||||
accountsByServiceIdentifier,
|
||||
|
@ -1599,7 +1641,8 @@ class MessageControllerTest {
|
|||
START_OF_DAY.plus(Duration.ofDays(1)))),
|
||||
200,
|
||||
Set.of(singleDeviceAccount, multiDeviceAccount),
|
||||
Set.of())
|
||||
Set.of(),
|
||||
null)
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,9 @@ import java.util.ArrayList;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
|
@ -30,12 +33,22 @@ import org.junit.jupiter.params.ParameterizedTest;
|
|||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MultiRecipientMismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.MultiRecipientMessageHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestRecipient;
|
||||
|
||||
class MessageSenderTest {
|
||||
|
||||
|
@ -60,7 +73,9 @@ class MessageSenderTest {
|
|||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
@ -71,7 +86,11 @@ class MessageSenderTest {
|
|||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
@ -82,7 +101,10 @@ class MessageSenderTest {
|
|||
|
||||
when(messagesManager.insert(any(), any())).thenReturn(Map.of(deviceId, clientPresent));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account, Map.of(device.getId(), message)));
|
||||
assertDoesNotThrow(() -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId)));
|
||||
|
||||
final MessageProtos.Envelope expectedMessage = ephemeral
|
||||
? message.toBuilder().setEphemeral(true).build()
|
||||
|
@ -97,23 +119,61 @@ class MessageSenderTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMessageMismatchedDevices() {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
final MessageProtos.Envelope message = MessageProtos.Envelope.newBuilder().build();
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final MismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MismatchedDevicesException.class, () -> messageSender.sendMessages(account,
|
||||
serviceIdentifier,
|
||||
Map.of(device.getId(), message),
|
||||
Map.of(device.getId(), registrationId + 1)));
|
||||
|
||||
assertEquals(new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId)),
|
||||
mismatchedDevicesException.getMismatchedDevices());
|
||||
}
|
||||
|
||||
@CartesianTest
|
||||
void sendMultiRecipientMessage(@CartesianTest.Values(booleans = {true, false}) final boolean clientPresent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean ephemeral,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean urgent,
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken) throws NotPushRegisteredException {
|
||||
@CartesianTest.Values(booleans = {true, false}) final boolean hasPushToken)
|
||||
throws NotPushRegisteredException, InvalidMessageException, InvalidVersionException {
|
||||
|
||||
final boolean expectPushNotificationAttempt = !clientPresent && !ephemeral;
|
||||
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
@ -125,12 +185,19 @@ class MessageSenderTest {
|
|||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, clientPresent))));
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(mock(SealedSenderMultiRecipientMessage.class),
|
||||
Collections.emptyMap(),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
assertDoesNotThrow(() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
ephemeral,
|
||||
urgent)
|
||||
.join());
|
||||
|
||||
if (expectPushNotificationAttempt) {
|
||||
|
@ -140,6 +207,49 @@ class MessageSenderTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void sendMultiRecipientMessageMismatchedDevices() throws InvalidMessageException, InvalidVersionException {
|
||||
final UUID accountIdentifier = UUID.randomUUID();
|
||||
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(accountIdentifier);
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final int registrationId = 17;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(accountIdentifier);
|
||||
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
|
||||
when(account.isIdentifiedBy(serviceIdentifier)).thenReturn(true);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
SealedSenderMultiRecipientMessage.parse(MultiRecipientMessageHelper.generateMultiRecipientMessage(
|
||||
List.of(new TestRecipient(serviceIdentifier, deviceId, registrationId + 1, new byte[48]))));
|
||||
|
||||
final SealedSenderMultiRecipientMessage.Recipient recipient =
|
||||
multiRecipientMessage.getRecipients().values().iterator().next();
|
||||
|
||||
when(messagesManager.insertMultiRecipientMessage(any(), any(), anyLong(), anyBoolean(), anyBoolean(), anyBoolean()))
|
||||
.thenReturn(CompletableFuture.completedFuture(Map.of(account, Map.of(deviceId, true))));
|
||||
|
||||
final MultiRecipientMismatchedDevicesException mismatchedDevicesException =
|
||||
assertThrows(MultiRecipientMismatchedDevicesException.class,
|
||||
() -> messageSender.sendMultiRecipientMessage(multiRecipientMessage,
|
||||
Map.of(recipient, account),
|
||||
System.currentTimeMillis(),
|
||||
false,
|
||||
false,
|
||||
true)
|
||||
.join());
|
||||
|
||||
assertEquals(Map.of(serviceIdentifier, new MismatchedDevices(Collections.emptySet(), Collections.emptySet(), Set.of(deviceId))),
|
||||
mismatchedDevicesException.getMismatchedDevicesByServiceIdentifier());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getDeliveryChannelName(final Device device, final String expectedChannelName) {
|
||||
|
@ -183,4 +293,87 @@ class MessageSenderTest {
|
|||
assertDoesNotThrow(() ->
|
||||
MessageSender.validateContentLength(MessageSender.MAX_MESSAGE_SIZE, false, false, false, null));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void getMismatchedDevices(final Account account,
|
||||
final ServiceIdentifier serviceIdentifier,
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
final byte excludedDeviceId,
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") final Optional<MismatchedDevices> expectedMismatchedDevices) {
|
||||
|
||||
assertEquals(expectedMismatchedDevices,
|
||||
MessageSender.getMismatchedDevices(account, serviceIdentifier, registrationIdsByDeviceId, excludedDeviceId));
|
||||
}
|
||||
|
||||
private static List<Arguments> getMismatchedDevices() {
|
||||
final byte primaryDeviceId = Device.PRIMARY_ID;
|
||||
final byte linkedDeviceId = primaryDeviceId + 1;
|
||||
final byte extraDeviceId = linkedDeviceId + 1;
|
||||
|
||||
final int primaryDeviceAciRegistrationId = 2;
|
||||
final int primaryDevicePniRegistrationId = 3;
|
||||
final int linkedDeviceAciRegistrationId = 5;
|
||||
final int linkedDevicePniRegistrationId = 7;
|
||||
|
||||
final Device primaryDevice = mock(Device.class);
|
||||
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
|
||||
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
|
||||
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId));
|
||||
|
||||
final Device linkedDevice = mock(Device.class);
|
||||
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
|
||||
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId));
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
|
||||
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
|
||||
when(account.getDevice(primaryDeviceId)).thenReturn(Optional.of(primaryDevice));
|
||||
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
|
||||
|
||||
final AciServiceIdentifier aciServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
|
||||
final PniServiceIdentifier pniServiceIdentifier = new PniServiceIdentifier(UUID.randomUUID());
|
||||
|
||||
return List.of(
|
||||
Arguments.argumentSet("Complete device list for ACI, no devices excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDeviceAciRegistrationId,
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list for PNI, no devices excluded",
|
||||
account,
|
||||
pniServiceIdentifier,
|
||||
Map.of(
|
||||
primaryDeviceId, primaryDevicePniRegistrationId,
|
||||
linkedDeviceId, linkedDevicePniRegistrationId
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Complete device list, device excluded",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId
|
||||
),
|
||||
primaryDeviceId,
|
||||
Optional.empty()),
|
||||
|
||||
Arguments.argumentSet("Mismatched devices",
|
||||
account,
|
||||
aciServiceIdentifier,
|
||||
Map.of(
|
||||
linkedDeviceId, linkedDeviceAciRegistrationId + 1,
|
||||
extraDeviceId, 17
|
||||
),
|
||||
MessageSender.NO_EXCLUDED_DEVICE_ID,
|
||||
Optional.of(new MismatchedDevices(Set.of(primaryDeviceId), Set.of(extraDeviceId), Set.of(linkedDeviceId))))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,10 +60,12 @@ import java.util.function.Consumer;
|
|||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.Timeout;
|
||||
import org.junit.jupiter.api.function.Executable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
|
@ -76,6 +78,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
|||
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
|
||||
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
|
@ -1705,4 +1708,47 @@ class AccountsManagerTest {
|
|||
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void validateCompleteDeviceList(final Account account, final Set<Byte> deviceIds, @Nullable final MismatchedDevicesException expectedException) {
|
||||
final Executable validateCompleteDeviceListExecutable =
|
||||
() -> AccountsManager.validateCompleteDeviceList(account, deviceIds);
|
||||
|
||||
if (expectedException != null) {
|
||||
final MismatchedDevicesException caughtException =
|
||||
assertThrows(MismatchedDevicesException.class, validateCompleteDeviceListExecutable);
|
||||
|
||||
assertEquals(expectedException.getMismatchedDevices(), caughtException.getMismatchedDevices());
|
||||
} else {
|
||||
assertDoesNotThrow(validateCompleteDeviceListExecutable);
|
||||
}
|
||||
}
|
||||
|
||||
private static List<Arguments> validateCompleteDeviceList() {
|
||||
final byte deviceId = Device.PRIMARY_ID;
|
||||
final byte extraDeviceId = deviceId + 1;
|
||||
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
|
||||
return List.of(
|
||||
Arguments.of(account, Set.of(deviceId), null),
|
||||
|
||||
Arguments.of(account, Set.of(deviceId, extraDeviceId),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Collections.emptySet(), Set.of(extraDeviceId), Collections.emptySet()))),
|
||||
|
||||
Arguments.of(account, Collections.emptySet(),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Set.of(deviceId), Collections.emptySet(), Collections.emptySet()))),
|
||||
|
||||
Arguments.of(account, Set.of(extraDeviceId),
|
||||
new MismatchedDevicesException(
|
||||
new MismatchedDevices(Set.of(deviceId), Set.of((byte) (extraDeviceId)), Collections.emptySet())))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,11 +32,12 @@ import org.signal.libsignal.protocol.IdentityKey;
|
|||
import org.signal.libsignal.protocol.ecc.Curve;
|
||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
|
||||
|
@ -105,7 +106,7 @@ public class ChangeNumberManagerTest {
|
|||
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null, null);
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
|
||||
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -119,7 +120,7 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap(), null);
|
||||
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any());
|
||||
verify(messageSender, never()).sendMessages(eq(account), any(), any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -159,7 +160,7 @@ public class ChangeNumberManagerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
@ -212,7 +213,7 @@ public class ChangeNumberManagerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
@ -263,7 +264,7 @@ public class ChangeNumberManagerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
@ -310,7 +311,7 @@ public class ChangeNumberManagerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
@ -359,7 +360,7 @@ public class ChangeNumberManagerTest {
|
|||
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
|
||||
ArgumentCaptor.forClass(Map.class);
|
||||
|
||||
verify(messageSender).sendMessages(any(), envelopeCaptor.capture());
|
||||
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any());
|
||||
|
||||
assertEquals(1, envelopeCaptor.getValue().size());
|
||||
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
|
||||
|
@ -372,82 +373,6 @@ public class ChangeNumberManagerTest {
|
|||
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
|
||||
}
|
||||
|
||||
@Test
|
||||
void changeNumberMismatchedRegistrationId() {
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
for (byte i = 1; i <= 3; i++) {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(i);
|
||||
when(device.getRegistrationId()).thenReturn((int) i);
|
||||
|
||||
devices.add(device);
|
||||
when(account.getDevice(i)).thenReturn(Optional.of(device));
|
||||
}
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
||||
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
|
||||
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
|
||||
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
|
||||
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
|
||||
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
|
||||
destinationDeviceId3, 89);
|
||||
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void updatePniKeysMismatchedRegistrationId() {
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
for (byte i = 1; i <= 3; i++) {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(i);
|
||||
when(device.getRegistrationId()).thenReturn((int) i);
|
||||
|
||||
devices.add(device);
|
||||
when(account.getDevice(i)).thenReturn(Optional.of(device));
|
||||
}
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
final byte destinationDeviceId2 = 2;
|
||||
final byte destinationDeviceId3 = 3;
|
||||
final List<IncomingMessage> messages = List.of(
|
||||
new IncomingMessage(1, destinationDeviceId2, 1, "foo".getBytes(StandardCharsets.UTF_8)),
|
||||
new IncomingMessage(1, destinationDeviceId3, 1, "foo".getBytes(StandardCharsets.UTF_8)));
|
||||
|
||||
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
|
||||
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
|
||||
|
||||
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
|
||||
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
|
||||
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
|
||||
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
|
||||
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
|
||||
destinationDeviceId3, 89);
|
||||
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds, null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void changeNumberMissingData() {
|
||||
final Account account = mock(Account.class);
|
||||
|
|
|
@ -1,273 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2022 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
import org.assertj.core.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class DestinationDeviceValidatorTest {
|
||||
|
||||
static Account mockAccountWithDeviceAndRegId(final Map<Byte, Integer> registrationIdsByDeviceId) {
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
registrationIdsByDeviceId.forEach((deviceId, registrationId) -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
});
|
||||
|
||||
return account;
|
||||
}
|
||||
|
||||
static Stream<Arguments> validateRegistrationIdsSource() {
|
||||
final byte id1 = 1;
|
||||
final byte id2 = 2;
|
||||
final byte id3 = 3;
|
||||
return Stream.of(
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)),
|
||||
Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 1492),
|
||||
Set.of(id1)),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 42),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42)),
|
||||
Map.of(id1, 0),
|
||||
null),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)),
|
||||
Map.of(id1, 0, id2, 42),
|
||||
Set.of(id2)),
|
||||
arguments(
|
||||
mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)),
|
||||
Map.of(id1, 41, id2, 257),
|
||||
Set.of(id1, id2))
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("validateRegistrationIdsSource")
|
||||
void testValidateRegistrationIds(
|
||||
Account account,
|
||||
Map<Byte, Integer> registrationIdsByDeviceId,
|
||||
Set<Byte> expectedStaleDeviceIds) throws Exception {
|
||||
if (expectedStaleDeviceIds != null) {
|
||||
Assertions.assertThat(assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(
|
||||
account,
|
||||
registrationIdsByDeviceId.entrySet(),
|
||||
Map.Entry::getKey,
|
||||
Map.Entry::getValue,
|
||||
false))
|
||||
.getStaleDevices())
|
||||
.hasSameElementsAs(expectedStaleDeviceIds);
|
||||
} else {
|
||||
DestinationDeviceValidator.validateRegistrationIds(account, registrationIdsByDeviceId.entrySet(),
|
||||
Map.Entry::getKey, Map.Entry::getValue, false);
|
||||
}
|
||||
}
|
||||
|
||||
static Account mockAccountWithDeviceAndEnabled(final Map<Byte, Boolean> enabledStateByDeviceId) {
|
||||
final Account account = mock(Account.class);
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
|
||||
enabledStateByDeviceId.forEach((deviceId, enabled) -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
|
||||
devices.add(device);
|
||||
});
|
||||
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
return account;
|
||||
}
|
||||
|
||||
static Stream<Arguments> validateCompleteDeviceList() {
|
||||
final byte id1 = 1;
|
||||
final byte id2 = 2;
|
||||
final byte id3 = 3;
|
||||
|
||||
final Account account = mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true));
|
||||
|
||||
return Stream.of(
|
||||
// Device IDs provided for all enabled devices
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id3),
|
||||
Set.of(id2),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID provided for disabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id2, id3),
|
||||
null,
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID omitted for enabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1),
|
||||
Set.of(id2, id3),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID included for disabled device, omitted for enabled device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1, id2),
|
||||
Set.of(id3),
|
||||
null,
|
||||
Collections.emptySet()),
|
||||
|
||||
// Device ID omitted for enabled device, included for device in excluded list
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id1),
|
||||
Set.of(id2, id3),
|
||||
Set.of(id1),
|
||||
Set.of(id1)
|
||||
),
|
||||
|
||||
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id2),
|
||||
Set.of(id3),
|
||||
null,
|
||||
Set.of(id1)
|
||||
),
|
||||
|
||||
// Device ID included for enabled device, omitted for excluded device
|
||||
arguments(
|
||||
account,
|
||||
Set.of(id3),
|
||||
Set.of(id2),
|
||||
null,
|
||||
Set.of(id1)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void validateCompleteDeviceList(
|
||||
Account account,
|
||||
Set<Byte> deviceIds,
|
||||
Collection<Byte> expectedMissingDeviceIds,
|
||||
Collection<Byte> expectedExtraDeviceIds,
|
||||
Set<Byte> excludedDeviceIds) throws Exception {
|
||||
|
||||
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
|
||||
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds));
|
||||
if (expectedMissingDeviceIds != null) {
|
||||
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
|
||||
.hasSameElementsAs(expectedMissingDeviceIds);
|
||||
}
|
||||
if (expectedExtraDeviceIds != null) {
|
||||
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
|
||||
}
|
||||
} else {
|
||||
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, excludedDeviceIds);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDuplicateDeviceIds() {
|
||||
final Account account = mockAccountWithDeviceAndRegId(Map.of(Device.PRIMARY_ID, 17));
|
||||
try {
|
||||
DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, 16), new Pair<>(Device.PRIMARY_ID, 17)), false);
|
||||
Assertions.fail("duplicate devices should throw StaleDevicesException");
|
||||
} catch (StaleDevicesException e) {
|
||||
Assertions.assertThat(e.getStaleDevices()).hasSameElementsAs(Collections.singletonList(Device.PRIMARY_ID));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testValidatePniRegistrationIds() {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
|
||||
|
||||
final int aciRegistrationId = 17;
|
||||
final int pniRegistrationId = 89;
|
||||
final int incorrectRegistrationId = aciRegistrationId + pniRegistrationId;
|
||||
|
||||
when(device.getRegistrationId()).thenReturn(aciRegistrationId);
|
||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(pniRegistrationId));
|
||||
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)), false));
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class,
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, pniRegistrationId)),
|
||||
false));
|
||||
|
||||
when(device.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
|
||||
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
false));
|
||||
assertDoesNotThrow(
|
||||
() -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, aciRegistrationId)),
|
||||
true));
|
||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), true));
|
||||
assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds(account,
|
||||
Stream.of(new Pair<>(Device.PRIMARY_ID, incorrectRegistrationId)), false));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue