diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index c73c2f133..67d822665 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.responses.ApiResponse; import java.util.Base64; import java.util.Objects; @@ -50,6 +51,8 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameLinkHandle; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; @@ -399,6 +402,7 @@ public class AccountController { return accounts .getByUsernameHash(hash) .map(Account::getUuid) + .map(AciServiceIdentifier::new) .map(AccountIdentifierResponse::new) .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)); } @@ -485,21 +489,32 @@ public class AccountController { return new EncryptedUsername(maybeEncryptedUsername.get()); } + @Operation( + summary = "Check whether an account exists", + description = """ + Enforced unauthenticated endpoint. Checks whether an account with a given identifier exists. + """ + ) + @ApiResponse(responseCode = "200", description = "An account with the given identifier was found.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "A client made an authenticated to this endpoint, and must not provide credentials.") + @ApiResponse(responseCode = "404", description = "An account was not found for the given identifier.") + @ApiResponse(responseCode = "422", description = "Invalid request format.") + @ApiResponse(responseCode = "429", description = "Rate-limited.") @HEAD - @Path("/account/{uuid}") + @Path("/account/{identifier}") @RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE) public Response accountExists( @Auth final Optional authenticatedAccount, - @PathParam("uuid") final UUID uuid) throws RateLimitExceededException { + + @Parameter(description = "An ACI or PNI account identifier to check") + @PathParam("identifier") final ServiceIdentifier accountIdentifier) { // Disallow clients from making authenticated requests to this endpoint requireNotAuthenticated(authenticatedAccount); - final Status status = accounts.getByAccountIdentifier(uuid) - .or(() -> accounts.getByPhoneNumberIdentifier(uuid)) - .isPresent() ? Status.OK : Status.NOT_FOUND; + final Optional maybeAccount = accounts.getByServiceIdentifier(accountIdentifier); - return Response.status(status).build(); + return Response.status(maybeAccount.map(ignored -> Status.OK).orElse(Status.NOT_FOUND)).build(); } @Timed diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 897d122ed..5f326610d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -57,6 +57,7 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.experiment.Experiment; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; @@ -207,7 +208,7 @@ public class KeysController { @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @Parameter(description="the account or phone-number identifier to retrieve keys for") - @PathParam("identifier") UUID targetUuid, + @PathParam("identifier") ServiceIdentifier targetIdentifier, @Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices") @PathParam("device_id") String deviceId, @@ -227,8 +228,7 @@ public class KeysController { final Account target; { - final Optional maybeTarget = accounts.getByAccountIdentifier(targetUuid) - .or(() -> accounts.getByPhoneNumberIdentifier(targetUuid)); + final Optional maybeTarget = accounts.getByServiceIdentifier(targetIdentifier); OptionalAccess.verify(account, accessKey, maybeTarget, deviceId); @@ -237,34 +237,39 @@ public class KeysController { if (account.isPresent()) { rateLimiters.getPreKeysLimiter().validate( - account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetUuid + account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetIdentifier.uuid() + "." + deviceId); } - final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid); - List devices = parseDeviceId(deviceId, target); List responseItems = new ArrayList<>(devices.size()); for (Device device : devices) { - UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; - ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); - ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).join().orElse(null); - KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).join().orElse(null) : null; + ECSignedPreKey signedECPreKey = switch (targetIdentifier.identityType()) { + case ACI -> device.getSignedPreKey(); + case PNI -> device.getPhoneNumberIdentitySignedPreKey(); + }; + + ECPreKey unsignedECPreKey = keys.takeEC(targetIdentifier.uuid(), device.getId()).join().orElse(null); + KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()).join().orElse(null) : null; compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), - keys.getEcSignedPreKey(identifier, device.getId())); + keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())); if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { - final int registrationId = usePhoneNumberIdentity ? - device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : - device.getRegistrationId(); + final int registrationId = switch (targetIdentifier.identityType()) { + case ACI -> device.getRegistrationId(); + case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); + }; responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey)); } } - final IdentityKey identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); + final IdentityKey identityKey = switch (targetIdentifier.identityType()) { + case ACI -> target.getIdentityKey(); + case PNI -> target.getPhoneNumberIdentityKey(); + }; if (responseItems.isEmpty()) { throw new WebApplicationException(Response.Status.NOT_FOUND); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 081f4b7bd..b2787e75b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -35,7 +36,6 @@ import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nonnull; @@ -48,6 +48,7 @@ import javax.ws.rs.DELETE; import javax.ws.rs.DefaultValue; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; +import javax.ws.rs.NotFoundException; import javax.ws.rs.POST; import javax.ws.rs.PUT; import javax.ws.rs.Path; @@ -82,6 +83,8 @@ import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -183,7 +186,7 @@ public class MessageController { @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor, - @PathParam("destination") UUID destinationUuid, + @PathParam("destination") ServiceIdentifier destinationIdentifier, @QueryParam("story") boolean isStory, @NotNull @Valid IncomingMessageList messages, @Context ContainerRequestContext context) throws RateLimitExceededException { @@ -195,7 +198,7 @@ public class MessageController { final String senderType; if (source.isPresent()) { - if (source.get().getAccount().isIdentifiedBy(destinationUuid)) { + if (source.get().getAccount().isIdentifiedBy(destinationIdentifier)) { senderType = SENDER_TYPE_SELF; } else { senderType = SENDER_TYPE_IDENTIFIED; @@ -227,7 +230,7 @@ public class MessageController { } try { - rateLimiters.getInboundMessageBytes().validate(destinationUuid, totalContentLength); + rateLimiters.getInboundMessageBytes().validate(destinationIdentifier.uuid(), totalContentLength); } catch (final RateLimitExceededException e) { if (dynamicConfigurationManager.getConfiguration().getInboundMessageByteLimitConfiguration().enforceInboundLimit()) { throw e; @@ -235,13 +238,12 @@ public class MessageController { } try { - boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid); + boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationIdentifier); Optional destination; if (!isSyncMessage) { - destination = accountsManager.getByAccountIdentifier(destinationUuid) - .or(() -> accountsManager.getByPhoneNumberIdentifier(destinationUuid)); + destination = accountsManager.getByServiceIdentifier(destinationIdentifier); } else { destination = source.map(AuthenticatedAccount::getAccount); } @@ -288,7 +290,7 @@ public class MessageController { messages.messages(), IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId, - destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); + destination.get().getPhoneNumberIdentifier().equals(destinationIdentifier.uuid())); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())), @@ -303,7 +305,7 @@ public class MessageController { source, destination.get(), destinationDevice.get(), - destinationUuid, + destinationIdentifier, messages.timestamp(), messages.online(), isStory, @@ -334,25 +336,20 @@ public class MessageController { /** * Build mapping of accounts to devices/registration IDs. - * - * @param multiRecipientMessage - * @param uuidToAccountMap - * @return */ private Map>> buildDeviceIdAndRegistrationIdMap( MultiRecipientMessage multiRecipientMessage, - Map uuidToAccountMap - ) { + Map accountsByServiceIdentifier) { - return Arrays.stream(multiRecipientMessage.getRecipients()) + return Arrays.stream(multiRecipientMessage.recipients()) // for normal messages, all recipients UUIDs are in the map, // but story messages might specify inactive UUIDs, which we // have previously filtered - .filter(r -> uuidToAccountMap.containsKey(r.getUuid())) + .filter(r -> accountsByServiceIdentifier.containsKey(r.uuid())) .collect(Collectors.toMap( - recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> accountsByServiceIdentifier.get(recipient.uuid()), recipient -> new HashSet<>( - Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), + Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))), (a, b) -> { a.addAll(b); return a; @@ -376,33 +373,29 @@ public class MessageController { @QueryParam("story") boolean isStory, @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { - // we skip "missing" accounts when story=true. - // otherwise, we return a 404 status code. - final Function> accountFinder = uuid -> { - Optional res = accountsManager.getByAccountIdentifier(uuid); - if (!isStory && res.isEmpty()) { - throw new WebApplicationException(Status.NOT_FOUND); - } - return res.stream(); - }; + final Map accountsByServiceIdentifier = new HashMap<>(); - // build a map from UUID to accounts - Map uuidToAccountMap = - Arrays.stream(multiRecipientMessage.getRecipients()) - .map(Recipient::getUuid) - .distinct() - .flatMap(accountFinder) - .collect(Collectors.toUnmodifiableMap( - Account::getUuid, - Function.identity())); + for (final Recipient recipient : multiRecipientMessage.recipients()) { + if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) { + final Optional maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid()); + + if (maybeAccount.isPresent()) { + accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get()); + } else { + if (!isStory) { + throw new NotFoundException(); + } + } + } + } // Stories will be checked by the client; we bypass access checks here for stories. if (!isStory) { - checkAccessKeys(accessKeys, uuidToAccountMap); + checkAccessKeys(accessKeys, accountsByServiceIdentifier.values()); } final Map>> accountToDeviceIdAndRegistrationIdMap = - buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, uuidToAccountMap); + buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier); // We might filter out all the recipients of a story (if none have enabled stories). // In this case there is no error so we should just return 200 now. @@ -412,7 +405,7 @@ public class MessageController { Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); - uuidToAccountMap.values().forEach(account -> { + accountsByServiceIdentifier.forEach((serviceIdentifier, account) -> { if (isStory) { checkStoryRateLimit(account); @@ -434,10 +427,10 @@ public class MessageController { accountToDeviceIdAndRegistrationIdMap.get(account).stream(), false); } catch (MismatchedDevicesException e) { - accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), + accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier, new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); } catch (StaleDevicesException e) { - accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices()))); + accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices()))); } }); if (!accountMismatchedDevices.isEmpty()) { @@ -455,7 +448,7 @@ public class MessageController { .build(); } - List uuids404 = Collections.synchronizedList(new ArrayList<>()); + List uuids404 = Collections.synchronizedList(new ArrayList<>()); try { final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of( @@ -463,18 +456,18 @@ public class MessageController { Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); - multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.getRecipients()) + multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.recipients()) .map(recipient -> (Callable) () -> { - Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); + Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid()); // we asserted this must exist in validateCompleteDeviceList - Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); + Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow(); sentMessageCounter.increment(); try { sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, - recipient, multiRecipientMessage.getCommonPayload()); + recipient, multiRecipientMessage.commonPayload()); } catch (NoSuchUserException e) { - uuids404.add(destinationAccount.getUuid()); + uuids404.add(recipient.uuid()); } return null; }) @@ -486,7 +479,7 @@ public class MessageController { return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } - private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map uuidToAccountMap) { + private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinationAccounts) { // We should not have null access keys when checking access; bail out early. if (accessKeys == null) { throw new WebApplicationException(Status.UNAUTHORIZED); @@ -494,7 +487,7 @@ public class MessageController { AtomicBoolean throwUnauthorized = new AtomicBoolean(false); byte[] empty = new byte[16]; final Optional UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); - byte[] combinedUnknownAccessKeys = uuidToAccountMap.values().stream() + byte[] combinedUnknownAccessKeys = destinationAccounts.stream() .map(account -> { if (account.isUnrestrictedUnidentifiedAccess()) { return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY; @@ -595,8 +588,8 @@ public class MessageController { if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { try { receiptSender.sendReceipt( - UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), - UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); + ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), + AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); } catch (Exception e) { logger.warn("Failed to send delivery receipt", e); } @@ -663,7 +656,7 @@ public class MessageController { Optional source, Account destinationAccount, Device destinationDevice, - UUID destinationUuid, + ServiceIdentifier destinationIdentifier, long timestamp, boolean online, boolean story, @@ -679,7 +672,7 @@ public class MessageController { Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null); Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); envelope = incomingMessage.toEnvelope( - destinationUuid, + destinationIdentifier, sourceAccount, sourceDeviceId, timestamp == 0 ? System.currentTimeMillis() : timestamp, @@ -709,10 +702,10 @@ public class MessageController { try { Envelope.Builder messageBuilder = Envelope.newBuilder(); long serverTimestamp = System.currentTimeMillis(); - byte[] recipientKeyMaterial = recipient.getPerRecipientKeyMaterial(); + byte[] recipientKeyMaterial = recipient.perRecipientKeyMaterial(); byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length]; - payload[0] = MultiRecipientMessageProvider.VERSION; + payload[0] = MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER; System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length); System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length); @@ -723,7 +716,7 @@ public class MessageController { .setContent(ByteString.copyFrom(payload)) .setStory(story) .setUrgent(urgent) - .setDestinationUuid(destinationAccount.getUuid().toString()); + .setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString()); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); } catch (NotPushRegisteredException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index bf86a9a87..bb45bf98a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -90,6 +90,9 @@ import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialPro import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes; import org.whispersystems.textsecuregcm.entities.UserCapabilities; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.s3.PolicySigner; @@ -234,33 +237,33 @@ public class ProfileController { @Timed @GET @Produces(MediaType.APPLICATION_JSON) - @Path("/{uuid}/{version}") + @Path("/{identifier}/{version}") public VersionedProfileResponse getProfile( @Auth Optional auth, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @Context ContainerRequestContext containerRequestContext, - @PathParam("uuid") UUID uuid, + @PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("version") String version) throws RateLimitExceededException { final Optional maybeRequester = auth.map(AuthenticatedAccount::getAccount); - final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, uuid); + final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, accountIdentifier); return buildVersionedProfileResponse(targetAccount, version, - isSelfProfileRequest(maybeRequester, uuid), + isSelfProfileRequest(maybeRequester, accountIdentifier), containerRequestContext); } @Timed @GET @Produces(MediaType.APPLICATION_JSON) - @Path("/{uuid}/{version}/{credentialRequest}") + @Path("/{identifier}/{version}/{credentialRequest}") public CredentialProfileResponse getProfile( @Auth Optional auth, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @Context ContainerRequestContext containerRequestContext, - @PathParam("uuid") UUID uuid, + @PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("version") String version, @PathParam("credentialRequest") String credentialRequest, @QueryParam("credentialType") String credentialType) @@ -271,8 +274,8 @@ public class ProfileController { } final Optional maybeRequester = auth.map(AuthenticatedAccount::getAccount); - final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, uuid); - final boolean isSelf = isSelfProfileRequest(maybeRequester, uuid); + final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, accountIdentifier); + final boolean isSelf = isSelfProfileRequest(maybeRequester, accountIdentifier); return buildExpiringProfileKeyCredentialProfileResponse(targetAccount, version, @@ -293,34 +296,38 @@ public class ProfileController { @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, - @PathParam("identifier") UUID identifier, + @PathParam("identifier") ServiceIdentifier identifier, @QueryParam("ca") boolean useCaCertificate) throws RateLimitExceededException { - final Optional maybeAccountByPni = accountsManager.getByPhoneNumberIdentifier(identifier); final Optional maybeRequester = auth.map(AuthenticatedAccount::getAccount); - final BaseProfileResponse profileResponse; + return switch (identifier.identityType()) { + case ACI -> { + final AciServiceIdentifier aciServiceIdentifier = (AciServiceIdentifier) identifier; - if (maybeAccountByPni.isPresent()) { - if (maybeRequester.isEmpty()) { - throw new WebApplicationException(Response.Status.UNAUTHORIZED); - } else { - rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); + final Account targetAccount = + verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, aciServiceIdentifier); + + yield buildBaseProfileResponseForAccountIdentity(targetAccount, + isSelfProfileRequest(maybeRequester, aciServiceIdentifier), + containerRequestContext); } + case PNI -> { + final Optional maybeAccountByPni = accountsManager.getByPhoneNumberIdentifier(identifier.uuid()); - OptionalAccess.verify(maybeRequester, Optional.empty(), maybeAccountByPni); + if (maybeRequester.isEmpty()) { + throw new WebApplicationException(Response.Status.UNAUTHORIZED); + } else { + rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); + } - profileResponse = buildBaseProfileResponseForPhoneNumberIdentity(maybeAccountByPni.get()); - } else { - final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, identifier); + OptionalAccess.verify(maybeRequester, Optional.empty(), maybeAccountByPni); - profileResponse = buildBaseProfileResponseForAccountIdentity(targetAccount, - isSelfProfileRequest(maybeRequester, identifier), - containerRequestContext); - } - - return profileResponse; + assert maybeAccountByPni.isPresent(); + yield buildBaseProfileResponseForPhoneNumberIdentity(maybeAccountByPni.get()); + } + }; } @Timed @@ -363,35 +370,24 @@ public class ProfileController { private void checkFingerprintAndAdd(BatchIdentityCheckRequest.Element element, Collection responseElements, MessageDigest md) { - final Optional maybeAccount; - final boolean usePhoneNumberIdentity; - if (element.aci() != null) { - maybeAccount = accountsManager.getByAccountIdentifier(element.aci()); - usePhoneNumberIdentity = false; - } else { - final Optional maybeAciAccount = accountsManager.getByAccountIdentifier(element.uuid()); - - if (maybeAciAccount.isEmpty()) { - maybeAccount = accountsManager.getByPhoneNumberIdentifier(element.uuid()); - usePhoneNumberIdentity = true; - } else { - maybeAccount = maybeAciAccount; - usePhoneNumberIdentity = false; - } - } + final Optional maybeAccount = accountsManager.getByServiceIdentifier(element.uuid()); maybeAccount.ifPresent(account -> { if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) { return; } - final IdentityKey identityKey = - usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); + + final IdentityKey identityKey = switch (element.uuid().identityType()) { + case ACI -> account.getIdentityKey(); + case PNI -> account.getPhoneNumberIdentityKey(); + }; + md.reset(); byte[] digest = md.digest(identityKey.serialize()); byte[] fingerprint = Util.truncate(digest, 4); if (!Arrays.equals(fingerprint, element.fingerprint())) { - responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKey)); + responseElements.add(new BatchIdentityCheckResponse.Element(element.uuid(), identityKey)); } }); } @@ -454,7 +450,7 @@ public class ProfileController { getAcceptableLanguagesForRequest(containerRequestContext), account.getBadges(), isSelf), - account.getUuid()); + new AciServiceIdentifier(account.getUuid())); } private BaseProfileResponse buildBaseProfileResponseForPhoneNumberIdentity(final Account account) { @@ -463,7 +459,7 @@ public class ProfileController { false, UserCapabilities.createForAccount(account), Collections.emptyList(), - account.getPhoneNumberIdentifier()); + new PniServiceIdentifier(account.getPhoneNumberIdentifier())); } private ExpiringProfileKeyCredentialResponse getExpiringProfileKeyCredentialResponse( @@ -562,7 +558,7 @@ public class ProfileController { * * @param maybeRequester the authenticated account requesting the profile, if any * @param maybeAccessKey an anonymous access key for the target account - * @param targetUuid the ACI of the target account + * @param accountIdentifier the ACI of the target account * * @return the target account * @@ -573,7 +569,7 @@ public class ProfileController { */ private Account verifyPermissionToReceiveAccountIdentityProfile(final Optional maybeRequester, final Optional maybeAccessKey, - final UUID targetUuid) throws RateLimitExceededException { + final AciServiceIdentifier accountIdentifier) throws RateLimitExceededException { if (maybeRequester.isEmpty() && maybeAccessKey.isEmpty()) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); @@ -583,7 +579,7 @@ public class ProfileController { rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); } - final Optional maybeTargetAccount = accountsManager.getByAccountIdentifier(targetUuid); + final Optional maybeTargetAccount = accountsManager.getByAccountIdentifier(accountIdentifier.uuid()); OptionalAccess.verify(maybeRequester, maybeAccessKey, maybeTargetAccount); assert maybeTargetAccount.isPresent(); @@ -591,7 +587,7 @@ public class ProfileController { return maybeTargetAccount.get(); } - private boolean isSelfProfileRequest(final Optional maybeRequester, final UUID targetUuid) { - return maybeRequester.map(requester -> requester.getUuid().equals(targetUuid)).orElse(false); + private boolean isSelfProfileRequest(final Optional maybeRequester, final AciServiceIdentifier targetIdentifier) { + return maybeRequester.map(requester -> requester.getUuid().equals(targetIdentifier.uuid())).orElse(false); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentifierResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentifierResponse.java index dfd33fd7e..e3067d628 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentifierResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountIdentifierResponse.java @@ -4,7 +4,14 @@ */ package org.whispersystems.textsecuregcm.entities; -import javax.validation.constraints.NotNull; -import java.util.UUID; -public record AccountIdentifierResponse(@NotNull UUID uuid) {} +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import javax.validation.constraints.NotNull; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; + +public record AccountIdentifierResponse(@NotNull + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.AciServiceIdentifierDeserializer.class) + AciServiceIdentifier uuid) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java index 4991355c6..55e939ebe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java @@ -5,22 +5,14 @@ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.UUID; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; -public class AccountMismatchedDevices { - @JsonProperty - public final UUID uuid; +public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + ServiceIdentifier uuid, - @JsonProperty - public final MismatchedDevices devices; - - public String toString() { - return "AccountMismatchedDevices(" + uuid + ", " + devices + ")"; - } - - public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) { - this.uuid = uuid; - this.devices = devices; - } + MismatchedDevices devices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java index bf1282fdc..031046d8c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java @@ -5,22 +5,14 @@ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonProperty; -import java.util.UUID; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; -public class AccountStaleDevices { - @JsonProperty - public final UUID uuid; +public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + ServiceIdentifier uuid, - @JsonProperty - public final StaleDevices devices; - - public String toString() { - return "AccountStaleDevices(" + uuid + ", " + devices + ")"; - } - - public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { - this.uuid = uuid; - this.devices = devices; - } + StaleDevices devices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java index 5a94fc7cb..a2b5f3a20 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java @@ -9,11 +9,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.signal.libsignal.protocol.IdentityKey; -import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import java.util.List; -import java.util.UUID; public class BaseProfileResponse { @@ -35,7 +35,9 @@ public class BaseProfileResponse { private List badges; @JsonProperty - private UUID uuid; + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + private ServiceIdentifier uuid; public BaseProfileResponse() { } @@ -45,7 +47,7 @@ public class BaseProfileResponse { final boolean unrestrictedUnidentifiedAccess, final UserCapabilities capabilities, final List badges, - final UUID uuid) { + final ServiceIdentifier uuid) { this.identityKey = identityKey; this.unidentifiedAccess = unidentifiedAccess; @@ -75,7 +77,7 @@ public class BaseProfileResponse { return badges; } - public UUID getUuid() { + public ServiceIdentifier getUuid() { return uuid; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java index 54bf15418..9b69b787c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckRequest.java @@ -5,13 +5,15 @@ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.util.List; -import java.util.UUID; -import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.validation.constraints.Size; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.util.ExactlySize; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List elements) { @@ -20,18 +22,13 @@ public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List elements) { - public record Element(@Deprecated - @JsonInclude(JsonInclude.Include.NON_EMPTY) - @Nullable UUID aci, - - @JsonInclude(JsonInclude.Include.NON_EMPTY) - @Nullable UUID uuid, + public record Element(@JsonInclude(JsonInclude.Include.NON_EMPTY) + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + @NotNull + ServiceIdentifier uuid, @NotNull @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) IdentityKey identityKey) { - - public Element { - if (aci == null && uuid == null) { - throw new IllegalArgumentException("aci and uuid cannot both be null"); - } - - if (aci != null && uuid != null) { - throw new IllegalArgumentException("aci and uuid cannot both be non-null"); - } - } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 17496e9a7..edcafba94 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -6,14 +6,15 @@ package org.whispersystems.textsecuregcm.entities; import com.google.protobuf.ByteString; import java.util.Base64; -import java.util.UUID; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; public record IncomingMessage(int type, long destinationDeviceId, int destinationRegistrationId, String content) { - public MessageProtos.Envelope toEnvelope(final UUID destinationUuid, + public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier, @Nullable Account sourceAccount, @Nullable Long sourceDeviceId, final long timestamp, @@ -32,13 +33,13 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio envelopeBuilder.setType(envelopeType) .setTimestamp(timestamp) .setServerTimestamp(System.currentTimeMillis()) - .setDestinationUuid(destinationUuid.toString()) + .setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) .setStory(story) .setUrgent(urgent); if (sourceAccount != null && sourceDeviceId != null) { envelopeBuilder - .setSourceUuid(sourceAccount.getUuid().toString()) + .setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString()) .setSourceDevice(sourceDeviceId.intValue()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java index a3b1d136c..ef6b6eda4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java @@ -6,31 +6,15 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.annotations.VisibleForTesting; import io.swagger.v3.oas.annotations.media.Schema; import java.util.List; -public class MismatchedDevices { - - @JsonProperty - @Schema(description = "Devices present on the account but absent in the request") - public List missingDevices; - - @JsonProperty - @Schema(description = "Devices absent on the request but present in the account") - public List extraDevices; - - @VisibleForTesting - public MismatchedDevices() {} - - public String toString() { - return "MismatchedDevices(" + missingDevices + ", " + extraDevices + ")"; - } - - public MismatchedDevices(List missingDevices, List extraDevices) { - this.missingDevices = missingDevices; - this.extraDevices = extraDevices; - } +public record MismatchedDevices(@JsonProperty + @Schema(description = "Devices present on the account but absent in the request") + List missingDevices, + @JsonProperty + @Schema(description = "Devices absent on the request but present in the account") + List extraDevices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java index 4469b6fea..ba6421c61 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.entities; import static com.codahale.metrics.MetricRegistry.name; import java.util.Arrays; -import java.util.UUID; +import java.util.Objects; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; import javax.validation.constraints.Max; @@ -16,58 +16,33 @@ import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; import javax.validation.constraints.Size; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.util.Pair; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; -public class MultiRecipientMessage { +public record MultiRecipientMessage( + @NotNull @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) @Valid Recipient[] recipients, + @NotNull @Size(min = 32) byte[] commonPayload) { private static final Counter REJECT_DUPLICATE_RECIPIENT_COUNTER = Metrics.counter( name(MessageController.class, "rejectDuplicateRecipients"), "multiRecipient", "false"); - public static class Recipient { - - @NotNull - private final UUID uuid; - - @Min(1) - private final long deviceId; - - @Min(0) - @Max(65535) - private final int registrationId; - - @Size(min = 48, max = 48) - @NotNull - private final byte[] perRecipientKeyMaterial; - - public Recipient(UUID uuid, long deviceId, int registrationId, byte[] perRecipientKeyMaterial) { - this.uuid = uuid; - this.deviceId = deviceId; - this.registrationId = registrationId; - this.perRecipientKeyMaterial = perRecipientKeyMaterial; - } - - public UUID getUuid() { - return uuid; - } - - public long getDeviceId() { - return deviceId; - } - - public int getRegistrationId() { - return registrationId; - } - - public byte[] getPerRecipientKeyMaterial() { - return perRecipientKeyMaterial; - } + public record Recipient(@NotNull + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + ServiceIdentifier uuid, + @Min(1) long deviceId, + @Min(0) @Max(65535) int registrationId, + @Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) { @Override public boolean equals(final Object o) { @@ -75,60 +50,48 @@ public class MultiRecipientMessage { return true; if (o == null || getClass() != o.getClass()) return false; - Recipient recipient = (Recipient) o; - - if (deviceId != recipient.deviceId) - return false; - if (registrationId != recipient.registrationId) - return false; - if (!uuid.equals(recipient.uuid)) - return false; - return Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial); + return deviceId == recipient.deviceId && registrationId == recipient.registrationId && uuid.equals(recipient.uuid) + && Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial); } @Override public int hashCode() { - int result = uuid.hashCode(); - result = 31 * result + (int) (deviceId ^ (deviceId >>> 32)); - result = 31 * result + registrationId; + int result = Objects.hash(uuid, deviceId, registrationId); result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial); return result; } - - public String toString() { - return "Recipient(" + uuid + ", " + deviceId + ", " + registrationId + ", " + Arrays.toString(perRecipientKeyMaterial) + ")"; - } } - @NotNull - @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) - @Valid - private final Recipient[] recipients; - - @NotNull - @Size(min = 32) - private final byte[] commonPayload; - public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) { this.recipients = recipients; this.commonPayload = commonPayload; } - public Recipient[] getRecipients() { - return recipients; - } - - public byte[] getCommonPayload() { - return commonPayload; - } - @AssertTrue public boolean hasNoDuplicateRecipients() { - boolean valid = Arrays.stream(recipients).map(r -> new Pair<>(r.getUuid(), r.getDeviceId())).distinct().count() == recipients.length; + boolean valid = + Arrays.stream(recipients).map(r -> new Pair<>(r.uuid(), r.deviceId())).distinct().count() == recipients.length; if (!valid) { REJECT_DUPLICATE_RECIPIENT_COUNTER.increment(); } return valid; } + + @Override + public boolean equals(final Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + MultiRecipientMessage that = (MultiRecipientMessage) o; + return Arrays.equals(recipients, that.recipients) && Arrays.equals(commonPayload, that.commonPayload); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(recipients); + result = 31 * result + Arrays.hashCode(commonPayload); + return result; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java index 1f61f1a7d..226d243a1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java @@ -5,28 +5,50 @@ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.protobuf.ByteString; import java.util.Arrays; import java.util.Objects; import java.util.UUID; import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; -public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, - UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, - long serverTimestamp, boolean urgent, boolean story, @Nullable byte[] reportSpamToken) { +public record OutgoingMessageEntity(UUID guid, + int type, + long timestamp, + + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + @Nullable + ServiceIdentifier sourceUuid, + + int sourceDevice, + + @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + ServiceIdentifier destinationUuid, + + @Nullable UUID updatedPni, + byte[] content, + long serverTimestamp, + boolean urgent, + boolean story, + @Nullable byte[] reportSpamToken) { public MessageProtos.Envelope toEnvelope() { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() .setType(MessageProtos.Envelope.Type.forNumber(type())) .setTimestamp(timestamp()) .setServerTimestamp(serverTimestamp()) - .setDestinationUuid(destinationUuid().toString()) + .setDestinationUuid(destinationUuid().toServiceIdentifierString()) .setServerGuid(guid().toString()) .setStory(story) .setUrgent(urgent); if (sourceUuid() != null) { - builder.setSourceUuid(sourceUuid().toString()); + builder.setSourceUuid(sourceUuid().toServiceIdentifierString()); builder.setSourceDevice(sourceDevice()); } @@ -51,9 +73,9 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab UUID.fromString(envelope.getServerGuid()), envelope.getType().getNumber(), envelope.getTimestamp(), - envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, + envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null, envelope.getSourceDevice(), - envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null, + envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null, envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.getContent().toByteArray(), envelope.getServerTimestamp(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java index 62635d2ad..cc3967df9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java @@ -6,27 +6,15 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.annotations.VisibleForTesting; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter; import java.util.List; import java.util.UUID; -public class SendMultiRecipientMessageResponse { - @JsonProperty - private List uuids404; - - public SendMultiRecipientMessageResponse() { - } - - public String toString() { - return "SendMultiRecipientMessageResponse(" + uuids404 + ")"; - } - - @VisibleForTesting - public List getUUIDs404() { - return this.uuids404; - } - - public SendMultiRecipientMessageResponse(final List uuids404) { - this.uuids404 = uuids404; - } +public record SendMultiRecipientMessageResponse(@JsonSerialize(contentUsing = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) + @JsonDeserialize(contentUsing = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) + List uuids404) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java index 5cef6fa75..bed26a51f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java @@ -10,20 +10,7 @@ import io.swagger.v3.oas.annotations.media.Schema; import java.util.List; -public class StaleDevices { - - @JsonProperty - @Schema(description = "Devices that are no longer active") - private List staleDevices; - - public StaleDevices() {} - - public String toString() { - return "StaleDevices(" + staleDevices + ")"; - } - - public StaleDevices(List staleDevices) { - this.staleDevices = staleDevices; - } - +public record StaleDevices(@JsonProperty + @Schema(description = "Devices that are no longer active") + List staleDevices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java new file mode 100644 index 000000000..ba08c42ba --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.UUID; +import io.swagger.v3.oas.annotations.media.Schema; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +/** + * An identifier for an account based on the account's ACI. + * + * @param uuid the account's ACI UUID + */ +@Schema( + type = "string", + description = "An identifier for an account based on the account's ACI" +) +public record AciServiceIdentifier(UUID uuid) implements ServiceIdentifier { + + private static final IdentityType IDENTITY_TYPE = IdentityType.ACI; + + @Override + public IdentityType identityType() { + return IDENTITY_TYPE; + } + + @Override + public String toServiceIdentifierString() { + return uuid.toString(); + } + + @Override + public byte[] toCompactByteArray() { + return UUIDUtil.toBytes(uuid); + } + + @Override + public byte[] toFixedWidthByteArray() { + final ByteBuffer byteBuffer = ByteBuffer.allocate(17); + byteBuffer.put(IDENTITY_TYPE.getBytePrefix()); + byteBuffer.putLong(uuid.getMostSignificantBits()); + byteBuffer.putLong(uuid.getLeastSignificantBits()); + byteBuffer.flip(); + + return byteBuffer.array(); + } + + public static AciServiceIdentifier valueOf(final String string) { + return new AciServiceIdentifier( + UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix()) + ? string.substring(IDENTITY_TYPE.getStringPrefix().length()) : string)); + } + + public static AciServiceIdentifier fromBytes(final byte[] bytes) { + final UUID uuid; + + if (bytes.length == 17) { + if (bytes[0] != IDENTITY_TYPE.getBytePrefix()) { + throw new IllegalArgumentException("Unexpected byte array prefix: " + HexFormat.of().formatHex(new byte[] { bytes[0] })); + } + + uuid = UUIDUtil.fromBytes(Arrays.copyOfRange(bytes, 1, bytes.length)); + } else { + uuid = UUIDUtil.fromBytes(bytes); + } + + return new AciServiceIdentifier(uuid); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/IdentityType.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/IdentityType.java new file mode 100644 index 000000000..720267095 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/IdentityType.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +public enum IdentityType { + ACI((byte) 0x00, "ACI:"), + PNI((byte) 0x01, "PNI:"); + + private final byte bytePrefix; + private final String stringPrefix; + + IdentityType(final byte bytePrefix, final String stringPrefix) { + this.bytePrefix = bytePrefix; + this.stringPrefix = stringPrefix; + } + + byte getBytePrefix() { + return bytePrefix; + } + + String getStringPrefix() { + return stringPrefix; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java new file mode 100644 index 000000000..2f184fd20 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import io.swagger.v3.oas.annotations.media.Schema; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.UUID; + +/** + * An identifier for an account based on the account's phone number identifier (PNI). + * + * @param uuid the account's PNI UUID + */ +@Schema( + type = "string", + description = "An identifier for an account based on the account's phone number identifier (PNI)" +) +public record PniServiceIdentifier(UUID uuid) implements ServiceIdentifier { + + private static final IdentityType IDENTITY_TYPE = IdentityType.PNI; + + @Override + public IdentityType identityType() { + return IDENTITY_TYPE; + } + + @Override + public String toServiceIdentifierString() { + return IDENTITY_TYPE.getStringPrefix() + uuid.toString(); + } + + @Override + public byte[] toCompactByteArray() { + return toFixedWidthByteArray(); + } + + @Override + public byte[] toFixedWidthByteArray() { + final ByteBuffer byteBuffer = ByteBuffer.allocate(17); + byteBuffer.put(IDENTITY_TYPE.getBytePrefix()); + byteBuffer.putLong(uuid.getMostSignificantBits()); + byteBuffer.putLong(uuid.getLeastSignificantBits()); + byteBuffer.flip(); + + return byteBuffer.array(); + } + + public static PniServiceIdentifier valueOf(final String string) { + if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) { + throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix"); + } + + return new PniServiceIdentifier(UUID.fromString(string.substring(IDENTITY_TYPE.getStringPrefix().length()))); + } + + public static PniServiceIdentifier fromBytes(final byte[] bytes) { + if (bytes.length == 17) { + if (bytes[0] != IDENTITY_TYPE.getBytePrefix()) { + throw new IllegalArgumentException("Unexpected byte array prefix: " + HexFormat.of().formatHex(new byte[] { bytes[0] })); + } + + return new PniServiceIdentifier(UUIDUtil.fromBytes(Arrays.copyOfRange(bytes, 1, bytes.length))); + } + + throw new IllegalArgumentException("Unexpected byte array length: " + bytes.length); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java new file mode 100644 index 000000000..c012f924c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.UUID; + +/** + * A "service identifier" is a tuple of a UUID and identity type that identifies an account and identity within the + * Signal service. + */ +@Schema( + type = "string", + description = "A service identifier is a tuple of a UUID and identity type that identifies an account and identity within the Signal service.", + subTypes = {AciServiceIdentifier.class, PniServiceIdentifier.class} +) +public interface ServiceIdentifier { + + /** + * Returns the identity type of this account identifier. + * + * @return the identity type of this account identifier + */ + IdentityType identityType(); + + /** + * Returns the UUID for this account identifier. + * + * @return the UUID for this account identifier + */ + UUID uuid(); + + /** + * Returns a string representation of this account identifier in a format that clients can unambiguously resolve into + * an identity type and UUID. + * + * @return a "strongly-typed" string representation of this account identifier + */ + String toServiceIdentifierString(); + + /** + * Returns a compact binary representation of this account identifier. + * + * @return a binary representation of this account identifier + */ + byte[] toCompactByteArray(); + + /** + * Returns a fixed-width binary representation of this account identifier. + * + * @return a binary representation of this account identifier + */ + byte[] toFixedWidthByteArray(); + + static ServiceIdentifier valueOf(final String string) { + try { + return AciServiceIdentifier.valueOf(string); + } catch (final IllegalArgumentException e) { + return PniServiceIdentifier.valueOf(string); + } + } + + static ServiceIdentifier fromBytes(final byte[] bytes) { + try { + return AciServiceIdentifier.fromBytes(bytes); + } catch (final IllegalArgumentException e) { + return PniServiceIdentifier.fromBytes(bytes); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java index 2b7027b82..ff7279573 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MessageMetrics.java @@ -9,22 +9,20 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import com.vdurmont.semver4j.Semver; import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tag; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Tags; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; public final class MessageMetrics { @@ -44,16 +42,15 @@ public final class MessageMetrics { final MessageProtos.Envelope envelope) { if (envelope.hasDestinationUuid()) { try { - final UUID destinationUuid = UUID.fromString(envelope.getDestinationUuid()); - measureAccountDestinationUuidMismatches(account, destinationUuid); + measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid())); } catch (final IllegalArgumentException ignored) { logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid()); } } } - private static void measureAccountDestinationUuidMismatches(final Account account, final UUID destinationUuid) { - if (!destinationUuid.equals(account.getUuid()) && !destinationUuid.equals(account.getPhoneNumberIdentifier())) { + private static void measureAccountDestinationUuidMismatches(final Account account, final ServiceIdentifier destinationIdentifier) { + if (!account.isIdentifiedBy(destinationIdentifier)) { // In all cases, this represents a mismatch between the account’s current PNI and its PNI when the message was // sent. This is an expected case, but if this metric changes significantly, it could indicate an issue to // investigate. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index 579f98eac..74b1ec164 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -11,7 +11,6 @@ import java.io.IOException; import java.io.InputStream; import java.lang.annotation.Annotation; import java.lang.reflect.Type; -import java.util.UUID; import javax.ws.rs.BadRequestException; import javax.ws.rs.Consumes; import javax.ws.rs.WebApplicationException; @@ -21,6 +20,7 @@ import javax.ws.rs.core.NoContentException; import javax.ws.rs.ext.MessageBodyReader; import javax.ws.rs.ext.Provider; import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; @Provider @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @@ -29,7 +29,30 @@ public class MultiRecipientMessageProvider implements MessageBodyReader type, Type genericType, Annotation[] annotations, MediaType mediaType) { @@ -44,23 +67,29 @@ public class MultiRecipientMessageProvider implements MessageBodyReader MAX_RECIPIENT_COUNT) { throw new BadRequestException("Maximum recipient count exceeded"); } MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)]; for (int i = 0; i < Math.toIntExact(count); i++) { - UUID uuid = readUuid(entityStream); + ServiceIdentifier identifier = readIdentifier(entityStream, version); long deviceId = readVarint(entityStream); int registrationId = readU16(entityStream); byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); if (perRecipientKeyMaterial.length != 48) { throw new IOException("Failed to read expected number of key material bytes for a recipient"); } - recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, registrationId, perRecipientKeyMaterial); + recipients[i] = new MultiRecipientMessage.Recipient(identifier, deviceId, registrationId, perRecipientKeyMaterial); } // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than @@ -73,32 +102,15 @@ public class MultiRecipientMessageProvider implements MessageBodyReader stream.readNBytes(16); + case EXPLICIT_ID -> stream.readNBytes(17); + }; - int read = stream.readNBytes(buffer, 0, 8); - if (read != 8) { - throw new IOException("Insufficient bytes for UUID"); - } - long msb = convertNetworkByteOrderToLong(buffer); - - read = stream.readNBytes(buffer, 0, 8); - if (read != 8) { - throw new IOException("Insufficient bytes for UUID"); - } - long lsb = convertNetworkByteOrderToLong(buffer); - - return new UUID(msb, lsb); - } - - private long convertNetworkByteOrderToLong(byte[] buffer) { - long result = 0; - for (int i = 0; i < 8; i++) { - result = (result << 8) | (buffer[i] & 0xFFL); - } - return result; + return ServiceIdentifier.fromBytes(uuidBytes); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index a16290f6a..ae235e528 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -9,11 +9,12 @@ import com.codahale.metrics.InstrumentedExecutorService; import com.codahale.metrics.SharedMetricRegistries; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; -import java.util.UUID; import java.util.concurrent.ExecutorService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -40,20 +41,20 @@ public class ReceiptSender { ; } - public void sendReceipt(UUID sourceUuid, long sourceDeviceId, UUID destinationUuid, long messageId) { - if (sourceUuid.equals(destinationUuid)) { + public void sendReceipt(ServiceIdentifier sourceIdentifier, long sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) { + if (sourceIdentifier.equals(destinationIdentifier)) { return; } executor.submit(() -> { try { - accountManager.getByAccountIdentifier(destinationUuid).ifPresentOrElse( + accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse( destinationAccount -> { final Envelope.Builder message = Envelope.newBuilder() .setServerTimestamp(System.currentTimeMillis()) - .setSourceUuid(sourceUuid.toString()) + .setSourceUuid(sourceIdentifier.toServiceIdentifierString()) .setSourceDevice((int) sourceDeviceId) - .setDestinationUuid(destinationUuid.toString()) + .setDestinationUuid(destinationIdentifier.toServiceIdentifierString()) .setTimestamp(messageId) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) .setUrgent(false); @@ -68,7 +69,7 @@ public class ReceiptSender { } } }, - () -> logger.info("No longer registered: {}", destinationUuid) + () -> logger.info("No longer registered: {}", destinationIdentifier) ); } catch (final Exception e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index b4f439fc2..22103d27e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -25,6 +25,7 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; @@ -123,13 +124,17 @@ public class Account { } /** - * Tests whether this account's account identifier or phone number identifier matches the given UUID. + * Tests whether this account's account identifier or phone number identifier (depending on the given service + * identifier's identity type) matches the given service identifier. * - * @param identifier the identifier to test + * @param serviceIdentifier the identifier to test * @return {@code true} if this account's identifier or phone number identifier matches */ - public boolean isIdentifiedBy(final UUID identifier) { - return uuid.equals(identifier) || (phoneNumberIdentifier != null && phoneNumberIdentifier.equals(identifier)); + public boolean isIdentifiedBy(final ServiceIdentifier serviceIdentifier) { + return switch (serviceIdentifier.identityType()) { + case ACI -> serviceIdentifier.uuid().equals(uuid); + case PNI -> serviceIdentifier.uuid().equals(phoneNumberIdentifier); + }; } public String getNumber() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 4e2aeb293..742718d65 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -52,6 +52,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisOperation; @@ -803,6 +804,13 @@ public class AccountsManager { ); } + public Optional getByServiceIdentifier(final ServiceIdentifier serviceIdentifier) { + return switch (serviceIdentifier.identityType()) { + case ACI -> getByAccountIdentifier(serviceIdentifier.uuid()); + case PNI -> getByPhoneNumberIdentifier(serviceIdentifier.uuid()); + }; + } + public Optional getByAccountIdentifier(final UUID uuid) { return checkRedisThenAccounts( getByUuidTimer, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 719f9023a..d16ab858a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; @@ -136,9 +137,9 @@ public class ChangeNumberManager { .setType(Envelope.Type.forNumber(message.type())) .setTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp) - .setDestinationUuid(sourceAndDestinationAccount.getUuid().toString()) + .setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) .setContent(ByteString.copyFrom(contents.get())) - .setSourceUuid(sourceAndDestinationAccount.getUuid().toString()) + .setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString()) .setSourceDevice((int) Device.MASTER_ID) .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) .setUrgent(true) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ServiceIdentifierAdapter.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ServiceIdentifierAdapter.java new file mode 100644 index 000000000..239fa002d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ServiceIdentifierAdapter.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import java.io.IOException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; + +public class ServiceIdentifierAdapter { + + public static class ServiceIdentifierSerializer extends JsonSerializer { + + @Override + public void serialize(final ServiceIdentifier identifier, final JsonGenerator jsonGenerator, final SerializerProvider serializers) + throws IOException { + + jsonGenerator.writeString(identifier.toServiceIdentifierString()); + } + } + + public static class AciServiceIdentifierDeserializer extends JsonDeserializer { + + @Override + public AciServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context) + throws IOException { + + return AciServiceIdentifier.valueOf(parser.getValueAsString()); + } + } + + public static class PniServiceIdentifierDeserializer extends JsonDeserializer { + + @Override + public PniServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context) + throws IOException { + + return PniServiceIdentifier.valueOf(parser.getValueAsString()); + } + } + + public static class ServiceIdentifierDeserializer extends JsonDeserializer { + + @Override + public ServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context) + throws IOException { + + return ServiceIdentifier.valueOf(parser.getValueAsString()); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 6b6ba4155..facba2146 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -40,6 +40,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -265,8 +267,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } try { - receiptSender.sendReceipt(UUID.fromString(message.getDestinationUuid()), - auth.getAuthenticatedDevice().getId(), UUID.fromString(message.getSourceUuid()), + receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()), + auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()), message.getTimestamp()); } catch (IllegalArgumentException e) { logger.error("Could not parse UUID: {}", message.getSourceUuid()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index d8651752a..009e7d23f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -70,6 +70,8 @@ import org.whispersystems.textsecuregcm.entities.RegistrationLock; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -869,10 +871,9 @@ class AccountControllerTest { final UUID accountIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID(); - when(accountsManager.getByAccountIdentifier(any())).thenReturn(Optional.empty()); - when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); - when(accountsManager.getByPhoneNumberIdentifier(any())).thenReturn(Optional.empty()); - when(accountsManager.getByPhoneNumberIdentifier(phoneNumberIdentifier)).thenReturn(Optional.of(account)); + when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(accountIdentifier))).thenReturn(Optional.of(account)); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(phoneNumberIdentifier))).thenReturn(Optional.of(account)); when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(mock(RateLimiter.class)); @@ -884,7 +885,7 @@ class AccountControllerTest { .getStatus()).isEqualTo(200); assertThat(resources.getJerseyTest() - .target(String.format("/v1/accounts/account/%s", phoneNumberIdentifier)) + .target(String.format("/v1/accounts/account/PNI:%s", phoneNumberIdentifier)) .request() .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() @@ -954,7 +955,7 @@ class AccountControllerTest { .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get(); assertThat(response.getStatus()).isEqualTo(200); - assertThat(response.readEntity(AccountIdentifierResponse.class).uuid()).isEqualTo(uuid); + assertThat(response.readEntity(AccountIdentifierResponse.class).uuid().uuid()).isEqualTo(uuid); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index fa1c9db76..5f36dde94 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -57,6 +57,8 @@ import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; @@ -77,7 +79,6 @@ class KeysControllerTest { private static final UUID EXISTS_UUID = UUID.randomUUID(); private static final UUID EXISTS_PNI = UUID.randomUUID(); - private static final String NOT_EXISTS_NUMBER = "+14152222220"; private static final UUID NOT_EXISTS_UUID = UUID.randomUUID(); private static final int SAMPLE_REGISTRATION_ID = 999; @@ -212,12 +213,10 @@ class KeysControllerTest { when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER); when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes())); - when(accounts.getByE164(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount)); - when(accounts.getByAccountIdentifier(EXISTS_UUID)).thenReturn(Optional.of(existsAccount)); - when(accounts.getByPhoneNumberIdentifier(EXISTS_PNI)).thenReturn(Optional.of(existsAccount)); + when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty()); - when(accounts.getByE164(NOT_EXISTS_NUMBER)).thenReturn(Optional.empty()); - when(accounts.getByAccountIdentifier(NOT_EXISTS_UUID)).thenReturn(Optional.empty()); + when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount)); + when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount)); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); @@ -384,7 +383,7 @@ class KeysControllerTest { @Test void validSingleRequestByPhoneNumberIdentifierTestV2() { PreKeyResponse result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/1", EXISTS_PNI)) + .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(PreKeyResponse.class); @@ -404,7 +403,7 @@ class KeysControllerTest { @Test void validSingleRequestPqByPhoneNumberIdentifierTestV2() { PreKeyResponse result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/1", EXISTS_PNI)) + .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI)) .queryParam("pq", "true") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) @@ -428,7 +427,7 @@ class KeysControllerTest { when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); PreKeyResponse result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/1", EXISTS_PNI)) + .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(PreKeyResponse.class); @@ -451,7 +450,7 @@ class KeysControllerTest { doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString()); Response result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/*", EXISTS_PNI)) + .target(String.format("/v2/keys/PNI:%s/*", EXISTS_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 42e571b17..5f60aa6d5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -29,6 +30,7 @@ import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; +import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.collect.ImmutableSet; import com.google.protobuf.ByteString; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; @@ -42,11 +44,13 @@ import java.nio.ByteOrder; import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Optional; import java.util.Random; +import java.util.Set; import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -69,6 +73,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -78,6 +83,8 @@ import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; +import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; +import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; @@ -91,11 +98,15 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; @@ -111,6 +122,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.websocket.Stories; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; @@ -191,11 +203,11 @@ class MessageControllerTest { Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); - when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); - when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); - when(accountsManager.getByAccountIdentifier(eq(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); - when(accountsManager.getByPhoneNumberIdentifier(MULTI_DEVICE_PNI)).thenReturn(Optional.of(multiDeviceAccount)); - when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); final DynamicDeliveryLatencyConfiguration deliveryLatencyConfiguration = mock(DynamicDeliveryLatencyConfiguration.class); when(deliveryLatencyConfiguration.instrumentedVersions()).thenReturn(Collections.emptyMap()); @@ -310,7 +322,7 @@ class MessageControllerTest { void testSingleDeviceCurrentByPni() throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI)) + .target(String.format("/v1/messages/PNI:%s", SINGLE_DEVICE_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), @@ -471,7 +483,7 @@ class MessageControllerTest { void testMultiDeviceByPni() throws Exception { Response response = resources.getJerseyTest() - .target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI)) + .target(String.format("/v1/messages/PNI:%s", MULTI_DEVICE_PNI)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"), @@ -543,14 +555,14 @@ class MessageControllerTest { OutgoingMessageEntity first = messages.get(0); assertEquals(first.timestamp(), timestampOne); assertEquals(first.guid(), messageGuidOne); - assertEquals(first.sourceUuid(), sourceUuid); + assertEquals(first.sourceUuid().uuid(), sourceUuid); assertEquals(updatedPniOne, first.updatedPni()); if (receiveStories) { OutgoingMessageEntity second = messages.get(1); assertEquals(second.timestamp(), timestampTwo); assertEquals(second.guid(), messageGuidTwo); - assertEquals(second.sourceUuid(), sourceUuid); + assertEquals(second.sourceUuid().uuid(), sourceUuid); assertNull(second.updatedPni()); } @@ -623,8 +635,8 @@ class MessageControllerTest { .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); - verify(receiptSender).sendReceipt(eq(AuthHelper.VALID_UUID), eq(1L), - eq(sourceUuid), eq(timestamp)); + verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L), + eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp)); response = resources.getJerseyTest() .target(String.format("/v1/messages/uuid/%s", uuid2)) @@ -920,28 +932,32 @@ class MessageControllerTest { } while (x != 0); } - private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception { - long msb = r.getUuid().getMostSignificantBits(); - long lsb = r.getUuid().getLeastSignificantBits(); - bb.putLong(msb); // uuid (first 8 bytes) - bb.putLong(lsb); // uuid (last 8 bytes) - writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes) - bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes) - bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes) + private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) { + if (useExplicitIdentifier) { + bb.put(r.uuid().toFixedWidthByteArray()); + } else { + bb.put(UUIDUtil.toBytes(r.uuid().uuid())); + } + + writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes) + bb.putShort((short) r.registrationId()); // registration id (2 bytes) + bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) } - private static InputStream initializeMultiPayload(List recipients, byte[] buffer) throws Exception { + private static InputStream initializeMultiPayload(List recipients, byte[] buffer, final boolean explicitIdentifiers) { // initialize a binary payload according to our wire format ByteBuffer bb = ByteBuffer.wrap(buffer); bb.order(ByteOrder.BIG_ENDIAN); // first write the header - bb.put(MultiRecipientMessageProvider.VERSION); // version byte + bb.put(explicitIdentifiers + ? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER + : MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte bb.put((byte)recipients.size()); // count varint Iterator it = recipients.iterator(); while (it.hasNext()) { - writeMultiPayloadRecipient(bb, it.next()); + writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers); } // now write the actual message body (empty for now) @@ -953,22 +969,22 @@ class MessageControllerTest { @ParameterizedTest @MethodSource - void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception { + void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception { final List recipients; if (recipientUUID == MULTI_DEVICE_UUID) { recipients = List.of( - new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) ); } else { - recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); } // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; //InputStream stream = initializeMultiPayload(recipientUUID, buffer); - InputStream stream = initializeMultiPayload(recipients, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier); // set up the entity to use in our PUT request Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); @@ -1058,31 +1074,48 @@ class MessageControllerTest { // Arguments here are: recipient-UUID, is-authorized?, is-story? private static Stream testMultiRecipientMessage() { return Stream.of( - Arguments.of(MULTI_DEVICE_UUID, false, true, true), - Arguments.of(MULTI_DEVICE_UUID, false, false, true), - Arguments.of(SINGLE_DEVICE_UUID, false, true, true), - Arguments.of(SINGLE_DEVICE_UUID, false, false, true), - Arguments.of(MULTI_DEVICE_UUID, true, true, true), - Arguments.of(MULTI_DEVICE_UUID, true, false, true), - Arguments.of(SINGLE_DEVICE_UUID, true, true, true), - Arguments.of(SINGLE_DEVICE_UUID, true, false, true), - Arguments.of(MULTI_DEVICE_UUID, false, true, false), - Arguments.of(MULTI_DEVICE_UUID, false, false, false), - Arguments.of(SINGLE_DEVICE_UUID, false, true, false), - Arguments.of(SINGLE_DEVICE_UUID, false, false, false), - Arguments.of(MULTI_DEVICE_UUID, true, true, false), - Arguments.of(MULTI_DEVICE_UUID, true, false, false), - Arguments.of(SINGLE_DEVICE_UUID, true, true, false), - Arguments.of(SINGLE_DEVICE_UUID, true, false, false) + Arguments.of(MULTI_DEVICE_UUID, false, true, true, false), + Arguments.of(MULTI_DEVICE_UUID, false, false, true, false), + Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false), + Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false), + Arguments.of(MULTI_DEVICE_UUID, true, true, true, false), + Arguments.of(MULTI_DEVICE_UUID, true, false, true, false), + Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false), + Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false), + Arguments.of(MULTI_DEVICE_UUID, false, true, false, false), + Arguments.of(MULTI_DEVICE_UUID, false, false, false, false), + Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false), + Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false), + Arguments.of(MULTI_DEVICE_UUID, true, true, false, false), + Arguments.of(MULTI_DEVICE_UUID, true, false, false, false), + Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false), + Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false), + Arguments.of(MULTI_DEVICE_UUID, false, true, true, true), + Arguments.of(MULTI_DEVICE_UUID, false, false, true, true), + Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true), + Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true), + Arguments.of(MULTI_DEVICE_UUID, true, true, true, true), + Arguments.of(MULTI_DEVICE_UUID, true, false, true, true), + Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true), + Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true), + Arguments.of(MULTI_DEVICE_UUID, false, true, false, true), + Arguments.of(MULTI_DEVICE_UUID, false, false, false, true), + Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true), + Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true), + Arguments.of(MULTI_DEVICE_UUID, true, true, false, true), + Arguments.of(MULTI_DEVICE_UUID, true, false, false, true), + Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true), + Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true) ); } - @Test - void testMultiRecipientRedisBombProtection() throws Exception { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { final List recipients = List.of( - new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); Response response = resources .getJerseyTest() @@ -1094,7 +1127,7 @@ class MessageControllerTest { .request() .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) - .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048]), MultiRecipientMessageProvider.MEDIA_TYPE)); + .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE)); checkBadMultiRecipientResponse(response, 422); } @@ -1118,22 +1151,22 @@ class MessageControllerTest { @ParameterizedTest @MethodSource - void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception { + void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) { final Recipient r1; if (known) { - r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); + r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); } else { - r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]); + r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]); } - Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); - Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); + Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); + Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); List recipients = List.of(r1, r2, r3); byte[] buffer = new byte[2048]; - InputStream stream = initializeMultiPayload(recipients, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, useExplicitIdentifier); // set up the entity to use in our PUT request Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); @@ -1167,10 +1200,170 @@ class MessageControllerTest { private static Stream testSendMultiRecipientMessageToUnknownAccounts() { return Stream.of( - Arguments.of(true, true), - Arguments.of(true, false), - Arguments.of(false, true), - Arguments.of(false, false)); + Arguments.of(true, true, false), + Arguments.of(true, false, false), + Arguments.of(false, true, false), + Arguments.of(false, false, false), + + Arguments.of(true, true, true), + Arguments.of(true, false, true), + Arguments.of(false, true, true), + Arguments.of(false, false, true) + ); + } + + @ParameterizedTest + @MethodSource + void sendMultiRecipientMessageMismatchedDevices(final ServiceIdentifier serviceIdentifier) + throws JsonProcessingException { + + final List recipients = List.of( + new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48])); + + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + // InputStream stream = initializeMultiPayload(recipientUUID, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, true); + + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // start building the request + final Invocation.Builder invocationBuilder = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", false) + .queryParam("ts", System.currentTimeMillis()) + .queryParam("story", false) + .queryParam("urgent", true) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + + // make the PUT request + final Response response = invocationBuilder.put(entity); + + assertEquals(409, response.getStatus()); + + final List mismatchedDevices = + SystemMapper.jsonMapper().readValue(response.readEntity(String.class), + SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class)); + + assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier, + new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))), + mismatchedDevices); + } + + private static Stream sendMultiRecipientMessageMismatchedDevices() { + return Stream.of( + Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), + Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); + } + + @ParameterizedTest + @MethodSource + void sendMultiRecipientMessageStaleDevices(final ServiceIdentifier serviceIdentifier) throws JsonProcessingException { + final List recipients = List.of( + new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1 + 1, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2 + 1, new byte[48])); + + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + // InputStream stream = initializeMultiPayload(recipientUUID, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, true); + + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // start building the request + final Invocation.Builder invocationBuilder = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", false) + .queryParam("ts", System.currentTimeMillis()) + .queryParam("story", false) + .queryParam("urgent", true) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + + // make the PUT request + final Response response = invocationBuilder.put(entity); + + assertEquals(410, response.getStatus()); + + final List staleDevices = + SystemMapper.jsonMapper().readValue(response.readEntity(String.class), + SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountStaleDevices.class)); + + assertEquals(1, staleDevices.size()); + assertEquals(serviceIdentifier, staleDevices.get(0).uuid()); + assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices())); + } + + private static Stream sendMultiRecipientMessageStaleDevices() { + return Stream.of( + Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), + Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); + } + + @ParameterizedTest + @MethodSource + void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier) + throws NotPushRegisteredException, InterruptedException { + + when(multiRecipientMessageExecutor.invokeAll(any())) + .thenAnswer(answer -> { + final List tasks = answer.getArgument(0, List.class); + tasks.forEach(c -> { + try { + c.call(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + return null; + }); + + final List recipients = List.of( + new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])); + + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + // InputStream stream = initializeMultiPayload(recipientUUID, buffer); + InputStream stream = initializeMultiPayload(recipients, buffer, true); + + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // start building the request + final Invocation.Builder invocationBuilder = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", false) + .queryParam("ts", System.currentTimeMillis()) + .queryParam("story", true) + .queryParam("urgent", true) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + + doThrow(NotPushRegisteredException.class) + .when(messageSender).sendMessage(any(), any(), any(), anyBoolean()); + + // make the PUT request + final SendMultiRecipientMessageResponse response = invocationBuilder.put(entity, SendMultiRecipientMessageResponse.class); + + assertEquals(List.of(serviceIdentifier), response.uuids404()); + } + + private static Stream sendMultiRecipientMessage404() { + return Stream.of( + Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), + Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); } private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { @@ -1185,7 +1378,7 @@ class MessageControllerTest { verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture()); assert (captor.getValue().size() == expectedCount); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); - assert (smrmr.getUUIDs404().isEmpty()); + assert (smrmr.uuids404().isEmpty()); } private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, @@ -1226,7 +1419,7 @@ class MessageControllerTest { int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 byte[] perKeyBytes = new byte[48]; // size=48, non-null rng.nextBytes(perKeyBytes); - return new Recipient(u1, d1, dr1, perKeyBytes); + return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes); } private static void roundTripVarint(long expected, byte [] bytes) throws Exception { @@ -1258,8 +1451,9 @@ class MessageControllerTest { } } - @Test - void testMultiPayloadRoundtrip() throws Exception { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception { Random rng = new java.util.Random(); List expected = new LinkedList<>(); for(int i = 0; i < 100; i++) { @@ -1267,11 +1461,11 @@ class MessageControllerTest { } byte[] buffer = new byte[100 + expected.size() * 100]; - InputStream entityStream = initializeMultiPayload(expected, buffer); + InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers); MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider(); // the provider ignores the headers, java reflection, etc. so we don't use those here. MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream); - List got = Arrays.asList(res.getRecipients()); + List got = Arrays.asList(res.recipients()); assertEquals(expected, got); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index 1faa26971..046e94f18 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -38,7 +38,6 @@ import java.util.Collections; import java.util.HexFormat; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.Executors; @@ -90,6 +89,9 @@ import org.whispersystems.textsecuregcm.entities.CreateProfileRequest; import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialProfileResponse; import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; @@ -202,6 +204,7 @@ class ProfileControllerTest { Account capabilitiesAccount = mock(Account.class); + when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(capabilitiesAccount.getIdentityKey()).thenReturn(ACCOUNT_IDENTITY_KEY); when(capabilitiesAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY); when(capabilitiesAccount.isEnabled()).thenReturn(true); @@ -209,20 +212,23 @@ class ProfileControllerTest { when(capabilitiesAccount.isAnnouncementGroupSupported()).thenReturn(true); when(capabilitiesAccount.isChangeNumberSupported()).thenReturn(true); + when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty()); + when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByPhoneNumberIdentifier(AuthHelper.VALID_PNI_TWO)).thenReturn(Optional.of(profileAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO))).thenReturn(Optional.of(profileAccount)); + when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO))).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByUsernameHash(USERNAME_HASH)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(capabilitiesAccount)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount)); + when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount)); when(profilesManager.get(eq(AuthHelper.VALID_UUID), eq("someversion"))).thenReturn(Optional.empty()); when(profilesManager.get(eq(AuthHelper.VALID_UUID_TWO), eq("validversion"))).thenReturn(Optional.of(new VersionedProfile( "validversion", "validname", "profiles/validavatar", "emoji", "about", null, "validcommitmnet".getBytes()))); - when(accountsManager.getByAccountIdentifier(AuthHelper.INVALID_UUID)).thenReturn(Optional.empty()); - clearInvocations(rateLimiter); clearInvocations(accountsManager); clearInvocations(usernameRateLimiter); @@ -308,14 +314,14 @@ class ProfileControllerTest { @Test void testProfileGetByPni() throws RateLimitExceededException { final BaseProfileResponse profile = resources.getJerseyTest() - .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) + .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(BaseProfileResponse.class); assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); assertThat(profile.getBadges()).isEmpty(); - assertThat(profile.getUuid()).isEqualTo(AuthHelper.VALID_PNI_TWO); + assertThat(profile.getUuid()).isEqualTo(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO)); assertThat(profile.getCapabilities()).isNotNull(); assertThat(profile.isUnrestrictedUnidentifiedAccess()).isFalse(); assertThat(profile.getUnidentifiedAccess()).isNull(); @@ -342,7 +348,7 @@ class ProfileControllerTest { @Test void testProfileGetByPniUnidentified() throws RateLimitExceededException { final Response response = resources.getJerseyTest() - .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) + .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO) .request() .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .get(); @@ -836,7 +842,7 @@ class ProfileControllerTest { assertThat(profile.getAboutEmoji()).isEqualTo("emoji"); assertThat(profile.getAvatar()).isEqualTo("profiles/validavatar"); assertThat(profile.getBaseProfileResponse().getCapabilities().gv1Migration()).isTrue(); - assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID_TWO); + assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO)); assertThat(profile.getBaseProfileResponse().getBadges()).hasSize(1).element(0).has(new Condition<>( badge -> "Test Badge".equals(badge.getName()), "has badge with expected name")); @@ -927,7 +933,9 @@ class ProfileControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(ExpiringProfileKeyCredentialProfileResponse.class); - assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID); + assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()) + .isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID)); + assertThat(profile.getCredential()).isNull(); verify(zkProfileOperations, never()).issueExpiringProfileKeyCredential(any(), any(), any(), any()); @@ -1092,7 +1100,8 @@ class ProfileControllerTest { .headers(authHeaders) .get(ExpiringProfileKeyCredentialProfileResponse.class); - assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID); + assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()) + .isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID)); assertThat(profile.getCredential()).isEqualTo(credentialResponse); verify(zkProfileOperations).issueExpiringProfileKeyCredential(credentialRequest, AuthHelper.VALID_UUID, profileKeyCommitment, expiration); @@ -1154,13 +1163,13 @@ class ProfileControllerTest { void testBatchIdentityCheck() { try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID), convertKeyToFingerprint(ACCOUNT_IDENTITY_KEY)), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, + new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO), convertKeyToFingerprint(ACCOUNT_TWO_IDENTITY_KEY)), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID), convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)) ))))) { assertThat(response).isNotNull(); @@ -1170,17 +1179,14 @@ class ProfileControllerTest { assertThat(identityCheckResponse.elements()).isNotNull().isEmpty(); } - final Condition isAnExpectedUuid = new Condition<>(element -> { - if (AuthHelper.VALID_UUID.equals(element.aci())) { - return Objects.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); - } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { - return Objects.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); - } else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) { - return Objects.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey()); - } else { - return false; - } - }, "is an expected UUID with the correct identity key"); + final Map expectedIdentityKeys = Map.of( + new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY, + new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, + new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO), ACCOUNT_TWO_IDENTITY_KEY); + + final Condition isAnExpectedUuid = + new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())), + "is an expected UUID with the correct identity key"); final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -1189,13 +1195,13 @@ class ProfileControllerTest { try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID), convertKeyToFingerprint(validAciIdentityKey)), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, + new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), convertKeyToFingerprint(secondValidPniIdentityKey)), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO), convertKeyToFingerprint(secondValidAciIdentityKey)), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID), convertKeyToFingerprint(invalidAciIdentityKey)) ))))) { assertThat(response).isNotNull(); @@ -1209,13 +1215,13 @@ class ProfileControllerTest { } final List largeElementList = new ArrayList<>(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint(validAciIdentityKey)), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint(secondValidPniIdentityKey)), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint(invalidAciIdentityKey)))); + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID), convertKeyToFingerprint(validAciIdentityKey)), + new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), convertKeyToFingerprint(secondValidPniIdentityKey)), + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID), convertKeyToFingerprint(invalidAciIdentityKey)))); for (int i = 0; i < 900; i++) { largeElementList.add( - new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); + new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(UUID.randomUUID()), convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); } try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() @@ -1233,27 +1239,25 @@ class ProfileControllerTest { @Test void testBatchIdentityCheckDeserialization() throws Exception { - final Condition isAnExpectedUuid = new Condition<>(element -> { - if (AuthHelper.VALID_UUID.equals(element.aci())) { - return ACCOUNT_IDENTITY_KEY.equals(element.identityKey()); - } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { - return ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY.equals(element.identityKey()); - } else { - return false; - } - }, "is an expected UUID with the correct identity key"); + final Map expectedIdentityKeys = Map.of( + new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY, + new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); + + final Condition isAnExpectedUuid = + new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())), + "is an expected UUID with the correct identity key"); // null properties are ok to omit final String json = String.format(""" { "elements": [ - { "aci": "%s", "fingerprint": "%s" }, { "uuid": "%s", "fingerprint": "%s" }, - { "aci": "%s", "fingerprint": "%s" } + { "uuid": "%s", "fingerprint": "%s" }, + { "uuid": "%s", "fingerprint": "%s" } ] } """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), - AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), + "PNI:" + AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() @@ -1277,50 +1281,34 @@ class ProfileControllerTest { @ParameterizedTest @MethodSource - void testBatchIdentityCheckDeserializationBadRequest(final String json) { + void testBatchIdentityCheckDeserializationBadRequest(final String json, final int expectedStatus) { try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.entity(json, "application/json"))) { assertThat(response).isNotNull(); - assertThat(response.getStatus()).isEqualTo(400); + assertThat(response.getStatus()).isEqualTo(expectedStatus); } } static Stream testBatchIdentityCheckDeserializationBadRequest() { return Stream.of( Arguments.of( // aci and uuid cannot both be null - """ - { - "elements": [ - { "aci": null, "uuid": null, "fingerprint": "%s" } - ] - } - """), - Arguments.of( // an empty string is also invalid - """ - { - "elements": [ - { "aci": "", "uuid": null, "fingerprint": "%s" } - ] - } - """ - ), - Arguments.of( // as is a blank string - """ - { - "elements": [ - { "aci": null, "uuid": " ", "fingerprint": "%s" } - ] - } - """), - Arguments.of( // aci and uuid cannot both be non-null String.format(""" - { - "elements": [ - { "aci": "%s", "uuid": "%s", "fingerprint": "%s" } - ] - } - """, AuthHelper.VALID_UUID, AuthHelper.VALID_PNI, - Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))))) + { + "elements": [ + { "uuid": null, "fingerprint": "%s" } + ] + } + """, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))), + 422), + Arguments.of( // a blank string is invalid + String.format(""" + { + "elements": [ + { "uuid": " ", "fingerprint": "%s" } + ] + } + """, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))), + 400) ); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index 1ebb97661..5c62c7e19 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -9,20 +9,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.Random; import java.util.UUID; -import java.util.stream.Stream; import javax.annotation.Nullable; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; +import org.junitpioneer.jupiter.cartesian.ArgumentSets; +import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; class OutgoingMessageEntityTest { - @ParameterizedTest - @MethodSource - void roundTripThroughEnvelope(@Nullable final UUID sourceUuid, @Nullable final UUID updatedPni) { + @CartesianTest + @CartesianTest.MethodFactory("roundTripThroughEnvelope") + void roundTripThroughEnvelope(@Nullable final ServiceIdentifier sourceIdentifier, + final ServiceIdentifier destinationIdentifier, + @Nullable final UUID updatedPni) { + final byte[] messageContent = new byte[16]; new Random().nextBytes(messageContent); @@ -35,9 +39,9 @@ class OutgoingMessageEntityTest { UUID.randomUUID(), MessageProtos.Envelope.Type.CIPHERTEXT_VALUE, messageTimestamp, - UUID.randomUUID(), - sourceUuid != null ? (int) Device.MASTER_ID : 0, - UUID.randomUUID(), + sourceIdentifier, + sourceIdentifier != null ? (int) Device.MASTER_ID : 0, + destinationIdentifier, updatedPni, messageContent, serverTimestamp, @@ -48,11 +52,14 @@ class OutgoingMessageEntityTest { assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); } - private static Stream roundTripThroughEnvelope() { - return Stream.of( - Arguments.of(UUID.randomUUID(), UUID.randomUUID()), - Arguments.of(UUID.randomUUID(), null), - Arguments.of(null, UUID.randomUUID())); + @SuppressWarnings("unused") + static ArgumentSets roundTripThroughEnvelope() { + return ArgumentSets.argumentsForFirstParameter(new AciServiceIdentifier(UUID.randomUUID()), + new PniServiceIdentifier(UUID.randomUUID()), + null) + .argumentsForNextParameter(new AciServiceIdentifier(UUID.randomUUID()), + new PniServiceIdentifier(UUID.randomUUID())) + .argumentsForNextParameter(UUID.randomUUID(), null); } @Test @@ -71,7 +78,7 @@ class OutgoingMessageEntityTest { IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA"); MessageProtos.Envelope baseEnvelope = message.toEnvelope( - UUID.randomUUID(), + new AciServiceIdentifier(UUID.randomUUID()), account, 123L, System.currentTimeMillis(), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifierTest.java new file mode 100644 index 000000000..6cefcfe11 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifierTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ByteBuffer; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class AciServiceIdentifierTest { + + @Test + void identityType() { + assertEquals(IdentityType.ACI, new AciServiceIdentifier(UUID.randomUUID()).identityType()); + } + + @Test + void toServiceIdentifierString() { + final UUID uuid = UUID.randomUUID(); + + assertEquals(uuid.toString(), new AciServiceIdentifier(uuid).toServiceIdentifierString()); + } + + @Test + void toCompactByteArray() { + final UUID uuid = UUID.randomUUID(); + + assertArrayEquals(UUIDUtil.toBytes(uuid), new AciServiceIdentifier(uuid).toCompactByteArray()); + } + + @Test + void toFixedWidthByteArray() { + final UUID uuid = UUID.randomUUID(); + + final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17); + expectedBytesBuffer.put((byte) 0x00); + expectedBytesBuffer.putLong(uuid.getMostSignificantBits()); + expectedBytesBuffer.putLong(uuid.getLeastSignificantBits()); + expectedBytesBuffer.flip(); + + assertArrayEquals(expectedBytesBuffer.array(), new AciServiceIdentifier(uuid).toFixedWidthByteArray()); + } + + @Test + void valueOf() { + final UUID uuid = UUID.randomUUID(); + + assertEquals(uuid, AciServiceIdentifier.valueOf(uuid.toString()).uuid()); + assertEquals(uuid, AciServiceIdentifier.valueOf("ACI:" + uuid).uuid()); + assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("Not a valid UUID")); + assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("PNI:" + uuid)); + } + + @Test + void fromBytes() { + final UUID uuid = UUID.randomUUID(); + + assertEquals(uuid, AciServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid)).uuid()); + + final byte[] prefixedBytes = new byte[17]; + prefixedBytes[0] = 0x00; + System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16); + + assertEquals(uuid, AciServiceIdentifier.fromBytes(prefixedBytes).uuid()); + + prefixedBytes[0] = 0x01; + + assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.fromBytes(prefixedBytes)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifierTest.java new file mode 100644 index 000000000..c53de62fc --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifierTest.java @@ -0,0 +1,71 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ByteBuffer; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class PniServiceIdentifierTest { + + @Test + void identityType() { + assertEquals(IdentityType.PNI, new PniServiceIdentifier(UUID.randomUUID()).identityType()); + } + + @Test + void toServiceIdentifierString() { + final UUID uuid = UUID.randomUUID(); + + assertEquals("PNI:" + uuid, new PniServiceIdentifier(uuid).toServiceIdentifierString()); + } + + @Test + void toByteArray() { + final UUID uuid = UUID.randomUUID(); + + final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17); + expectedBytesBuffer.put((byte) 0x01); + expectedBytesBuffer.putLong(uuid.getMostSignificantBits()); + expectedBytesBuffer.putLong(uuid.getLeastSignificantBits()); + expectedBytesBuffer.flip(); + + assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toCompactByteArray()); + assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toFixedWidthByteArray()); + } + + @Test + void valueOf() { + final UUID uuid = UUID.randomUUID(); + + assertEquals(uuid, PniServiceIdentifier.valueOf("PNI:" + uuid).uuid()); + assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf(uuid.toString())); + assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("Not a valid UUID")); + assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("ACI:" + uuid)); + } + + @Test + void fromBytes() { + final UUID uuid = UUID.randomUUID(); + + assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid))); + + final byte[] prefixedBytes = new byte[17]; + prefixedBytes[0] = 0x00; + System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16); + + assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(prefixedBytes)); + + prefixedBytes[0] = 0x01; + + assertEquals(uuid, PniServiceIdentifier.fromBytes(prefixedBytes).uuid()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifierTest.java new file mode 100644 index 000000000..b06bca9e1 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifierTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.identity; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.UUID; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class ServiceIdentifierTest { + + @ParameterizedTest + @MethodSource + void valueOf(final String identifierString, final IdentityType expectedIdentityType, final UUID expectedUuid) { + final ServiceIdentifier serviceIdentifier = ServiceIdentifier.valueOf(identifierString); + + assertEquals(expectedIdentityType, serviceIdentifier.identityType()); + assertEquals(expectedUuid, serviceIdentifier.uuid()); + } + + private static Stream valueOf() { + final UUID uuid = UUID.randomUUID(); + + return Stream.of( + Arguments.of(uuid.toString(), IdentityType.ACI, uuid), + Arguments.of("ACI:" + uuid, IdentityType.ACI, uuid), + Arguments.of("PNI:" + uuid, IdentityType.PNI, uuid)); + } + + @ParameterizedTest + @ValueSource(strings = {"Not a valid UUID", "BAD:a9edc243-3e93-45d4-95c6-e3a84cd4a254"}) + void valueOfIllegalArgument(final String identifierString) { + assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.valueOf(identifierString)); + } + + @ParameterizedTest + @MethodSource + void fromBytes(final byte[] bytes, final IdentityType expectedIdentityType, final UUID expectedUuid) { + final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromBytes(bytes); + + assertEquals(expectedIdentityType, serviceIdentifier.identityType()); + assertEquals(expectedUuid, serviceIdentifier.uuid()); + } + + private static Stream fromBytes() { + final UUID uuid = UUID.randomUUID(); + + final byte[] aciPrefixedBytes = new byte[17]; + aciPrefixedBytes[0] = 0x00; + System.arraycopy(UUIDUtil.toBytes(uuid), 0, aciPrefixedBytes, 1, 16); + + final byte[] pniPrefixedBytes = new byte[17]; + pniPrefixedBytes[0] = 0x01; + System.arraycopy(UUIDUtil.toBytes(uuid), 0, pniPrefixedBytes, 1, 16); + + return Stream.of( + Arguments.of(UUIDUtil.toBytes(uuid), IdentityType.ACI, uuid), + Arguments.of(aciPrefixedBytes, IdentityType.ACI, uuid), + Arguments.of(pniPrefixedBytes, IdentityType.PNI, uuid)); + } + + @ParameterizedTest + @MethodSource + void fromBytesIllegalArgument(final byte[] bytes) { + assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.fromBytes(bytes)); + } + + private static Stream fromBytesIllegalArgument() { + final byte[] invalidPrefixBytes = new byte[17]; + invalidPrefixBytes[0] = (byte) 0xff; + + return Stream.of( + Arguments.of(new byte[0]), + Arguments.of(new byte[15]), + Arguments.of(new byte[18]), + Arguments.of(invalidPrefixBytes)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java index 32c1b5fc2..c413d2cb9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.metrics; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -21,6 +22,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; class MessageMetricsTest { @@ -35,6 +39,9 @@ class MessageMetricsTest { void setup() { when(account.getUuid()).thenReturn(aci); when(account.getPhoneNumberIdentifier()).thenReturn(pni); + when(account.isIdentifiedBy(any())).thenReturn(false); + when(account.isIdentifiedBy(new AciServiceIdentifier(aci))).thenReturn(true); + when(account.isIdentifiedBy(new PniServiceIdentifier(pni))).thenReturn(true); Metrics.globalRegistry.clear(); simpleMeterRegistry = new SimpleMeterRegistry(); Metrics.globalRegistry.add(simpleMeterRegistry); @@ -49,46 +56,46 @@ class MessageMetricsTest { @Test void measureAccountOutgoingMessageUuidMismatches() { - final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(aci); + final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(new AciServiceIdentifier(aci)); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToAci); Optional counter = findCounter(simpleMeterRegistry); assertTrue(counter.isEmpty()); - final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(pni); + final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(new PniServiceIdentifier(pni)); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToPni); counter = findCounter(simpleMeterRegistry); assertTrue(counter.isEmpty()); - final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(otherUuid); + final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(new AciServiceIdentifier(otherUuid)); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToOtherUuid); counter = findCounter(simpleMeterRegistry); assertEquals(1.0, counter.map(Counter::count).orElse(0.0)); } - private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { - return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false, null); + private OutgoingMessageEntity createOutgoingMessageEntity(final ServiceIdentifier destinationIdentifier) { + return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationIdentifier, null, new byte[]{}, 1, true, false, null); } @Test void measureAccountEnvelopeUuidMismatches() { - final MessageProtos.Envelope envelopeToAci = createEnvelope(aci); + final MessageProtos.Envelope envelopeToAci = createEnvelope(new AciServiceIdentifier(aci)); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToAci); Optional counter = findCounter(simpleMeterRegistry); assertTrue(counter.isEmpty()); - final MessageProtos.Envelope envelopeToPni = createEnvelope(pni); + final MessageProtos.Envelope envelopeToPni = createEnvelope(new PniServiceIdentifier(pni)); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToPni); counter = findCounter(simpleMeterRegistry); assertTrue(counter.isEmpty()); - final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(otherUuid); + final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(new AciServiceIdentifier(otherUuid)); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToOtherUuid); counter = findCounter(simpleMeterRegistry); @@ -101,11 +108,11 @@ class MessageMetricsTest { assertEquals(1.0, counter.map(Counter::count).orElse(0.0)); } - private MessageProtos.Envelope createEnvelope(UUID destinationUuid) { + private MessageProtos.Envelope createEnvelope(ServiceIdentifier destinationIdentifier) { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); - if (destinationUuid != null) { - builder.setDestinationUuid(destinationUuid.toString()); + if (destinationIdentifier != null) { + builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString()); } return builder.build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index d4b00a8d7..19cec8183 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -61,6 +62,8 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -208,6 +211,21 @@ class AccountsManagerTest { mock(Clock.class)); } + @Test + void testGetByServiceIdentifier() { + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString()); + when(commands.get(eq("Account3::" + aci))).thenReturn( + "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"); + + assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent()); + assertTrue(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni)).isPresent()); + assertFalse(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(pni)).isPresent()); + assertFalse(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(aci)).isPresent()); + } + @Test void testGetAccountByNumberInCache() { UUID uuid = UUID.randomUUID(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 00f1e5ee4..1595c4cec 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -55,6 +55,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -225,7 +226,7 @@ class WebSocketConnectionTest { verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId), eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); - verify(receiptSender, times(1)).sendReceipt(eq(accountUuid), eq(deviceId), eq(senderOneUuid), + verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)), eq(2222L)); connection.stop(); @@ -369,7 +370,7 @@ class WebSocketConnectionTest { futures.get(1).complete(response); futures.get(0).completeExceptionally(new IOException()); - verify(receiptSender, times(1)).sendReceipt(eq(account.getUuid()), eq(deviceId), eq(senderTwoUuid), + verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), eq(secondMessage.getTimestamp())); connection.stop();