Introduce "service identifiers"

This commit is contained in:
Jon Chambers 2023-07-21 09:34:10 -04:00 committed by GitHub
parent 4a6c7152cf
commit abb32bd919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1304 additions and 588 deletions

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.util.Base64; import java.util.Base64;
import java.util.Objects; import java.util.Objects;
@ -50,6 +51,8 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameLinkHandle; import org.whispersystems.textsecuregcm.entities.UsernameLinkHandle;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -399,6 +402,7 @@ public class AccountController {
return accounts return accounts
.getByUsernameHash(hash) .getByUsernameHash(hash)
.map(Account::getUuid) .map(Account::getUuid)
.map(AciServiceIdentifier::new)
.map(AccountIdentifierResponse::new) .map(AccountIdentifierResponse::new)
.orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)); .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND));
} }
@ -485,21 +489,32 @@ public class AccountController {
return new EncryptedUsername(maybeEncryptedUsername.get()); return new EncryptedUsername(maybeEncryptedUsername.get());
} }
@Operation(
summary = "Check whether an account exists",
description = """
Enforced unauthenticated endpoint. Checks whether an account with a given identifier exists.
"""
)
@ApiResponse(responseCode = "200", description = "An account with the given identifier was found.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "400", description = "A client made an authenticated to this endpoint, and must not provide credentials.")
@ApiResponse(responseCode = "404", description = "An account was not found for the given identifier.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate-limited.")
@HEAD @HEAD
@Path("/account/{uuid}") @Path("/account/{identifier}")
@RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE) @RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
public Response accountExists( public Response accountExists(
@Auth final Optional<AuthenticatedAccount> authenticatedAccount, @Auth final Optional<AuthenticatedAccount> authenticatedAccount,
@PathParam("uuid") final UUID uuid) throws RateLimitExceededException {
@Parameter(description = "An ACI or PNI account identifier to check")
@PathParam("identifier") final ServiceIdentifier accountIdentifier) {
// Disallow clients from making authenticated requests to this endpoint // Disallow clients from making authenticated requests to this endpoint
requireNotAuthenticated(authenticatedAccount); requireNotAuthenticated(authenticatedAccount);
final Status status = accounts.getByAccountIdentifier(uuid) final Optional<Account> maybeAccount = accounts.getByServiceIdentifier(accountIdentifier);
.or(() -> accounts.getByPhoneNumberIdentifier(uuid))
.isPresent() ? Status.OK : Status.NOT_FOUND;
return Response.status(status).build(); return Response.status(maybeAccount.map(ignored -> Status.OK).orElse(Status.NOT_FOUND)).build();
} }
@Timed @Timed

View File

@ -57,6 +57,7 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem;
import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -207,7 +208,7 @@ public class KeysController {
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for") @Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") UUID targetUuid, @PathParam("identifier") ServiceIdentifier targetIdentifier,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices") @Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId, @PathParam("device_id") String deviceId,
@ -227,8 +228,7 @@ public class KeysController {
final Account target; final Account target;
{ {
final Optional<Account> maybeTarget = accounts.getByAccountIdentifier(targetUuid) final Optional<Account> maybeTarget = accounts.getByServiceIdentifier(targetIdentifier);
.or(() -> accounts.getByPhoneNumberIdentifier(targetUuid));
OptionalAccess.verify(account, accessKey, maybeTarget, deviceId); OptionalAccess.verify(account, accessKey, maybeTarget, deviceId);
@ -237,34 +237,39 @@ public class KeysController {
if (account.isPresent()) { if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate( rateLimiters.getPreKeysLimiter().validate(
account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetUuid account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetIdentifier.uuid()
+ "." + deviceId); + "." + deviceId);
} }
final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid);
List<Device> devices = parseDeviceId(deviceId, target); List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size()); List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
for (Device device : devices) { for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; ECSignedPreKey signedECPreKey = switch (targetIdentifier.identityType()) {
ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); case ACI -> device.getSignedPreKey();
ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).join().orElse(null); case PNI -> device.getPhoneNumberIdentitySignedPreKey();
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).join().orElse(null) : null; };
ECPreKey unsignedECPreKey = keys.takeEC(targetIdentifier.uuid(), device.getId()).join().orElse(null);
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()).join().orElse(null) : null;
compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(identifier, device.getId())); keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()));
if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ? final int registrationId = switch (targetIdentifier.identityType()) {
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : case ACI -> device.getRegistrationId();
device.getRegistrationId(); case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey)); responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
} }
} }
final IdentityKey identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); final IdentityKey identityKey = switch (targetIdentifier.identityType()) {
case ACI -> target.getIdentityKey();
case PNI -> target.getPhoneNumberIdentityKey();
};
if (responseItems.isEmpty()) { if (responseItems.isEmpty()) {
throw new WebApplicationException(Response.Status.NOT_FOUND); throw new WebApplicationException(Response.Status.NOT_FOUND);

View File

@ -23,6 +23,7 @@ import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -35,7 +36,6 @@ import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -48,6 +48,7 @@ import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue; import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST; import javax.ws.rs.POST;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
import javax.ws.rs.Path; import javax.ws.rs.Path;
@ -82,6 +83,8 @@ import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@ -183,7 +186,7 @@ public class MessageController {
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor, @HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor,
@PathParam("destination") UUID destinationUuid, @PathParam("destination") ServiceIdentifier destinationIdentifier,
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@NotNull @Valid IncomingMessageList messages, @NotNull @Valid IncomingMessageList messages,
@Context ContainerRequestContext context) throws RateLimitExceededException { @Context ContainerRequestContext context) throws RateLimitExceededException {
@ -195,7 +198,7 @@ public class MessageController {
final String senderType; final String senderType;
if (source.isPresent()) { if (source.isPresent()) {
if (source.get().getAccount().isIdentifiedBy(destinationUuid)) { if (source.get().getAccount().isIdentifiedBy(destinationIdentifier)) {
senderType = SENDER_TYPE_SELF; senderType = SENDER_TYPE_SELF;
} else { } else {
senderType = SENDER_TYPE_IDENTIFIED; senderType = SENDER_TYPE_IDENTIFIED;
@ -227,7 +230,7 @@ public class MessageController {
} }
try { try {
rateLimiters.getInboundMessageBytes().validate(destinationUuid, totalContentLength); rateLimiters.getInboundMessageBytes().validate(destinationIdentifier.uuid(), totalContentLength);
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
if (dynamicConfigurationManager.getConfiguration().getInboundMessageByteLimitConfiguration().enforceInboundLimit()) { if (dynamicConfigurationManager.getConfiguration().getInboundMessageByteLimitConfiguration().enforceInboundLimit()) {
throw e; throw e;
@ -235,13 +238,12 @@ public class MessageController {
} }
try { try {
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationUuid); boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationIdentifier);
Optional<Account> destination; Optional<Account> destination;
if (!isSyncMessage) { if (!isSyncMessage) {
destination = accountsManager.getByAccountIdentifier(destinationUuid) destination = accountsManager.getByServiceIdentifier(destinationIdentifier);
.or(() -> accountsManager.getByPhoneNumberIdentifier(destinationUuid));
} else { } else {
destination = source.map(AuthenticatedAccount::getAccount); destination = source.map(AuthenticatedAccount::getAccount);
} }
@ -288,7 +290,7 @@ public class MessageController {
messages.messages(), messages.messages(),
IncomingMessage::destinationDeviceId, IncomingMessage::destinationDeviceId,
IncomingMessage::destinationRegistrationId, IncomingMessage::destinationRegistrationId,
destination.get().getPhoneNumberIdentifier().equals(destinationUuid)); destination.get().getPhoneNumberIdentifier().equals(destinationIdentifier.uuid()));
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.online())),
@ -303,7 +305,7 @@ public class MessageController {
source, source,
destination.get(), destination.get(),
destinationDevice.get(), destinationDevice.get(),
destinationUuid, destinationIdentifier,
messages.timestamp(), messages.timestamp(),
messages.online(), messages.online(),
isStory, isStory,
@ -334,25 +336,20 @@ public class MessageController {
/** /**
* Build mapping of accounts to devices/registration IDs. * Build mapping of accounts to devices/registration IDs.
*
* @param multiRecipientMessage
* @param uuidToAccountMap
* @return
*/ */
private Map<Account, Set<Pair<Long, Integer>>> buildDeviceIdAndRegistrationIdMap( private Map<Account, Set<Pair<Long, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage, MultiRecipientMessage multiRecipientMessage,
Map<UUID, Account> uuidToAccountMap Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {
) {
return Arrays.stream(multiRecipientMessage.getRecipients()) return Arrays.stream(multiRecipientMessage.recipients())
// for normal messages, all recipients UUIDs are in the map, // for normal messages, all recipients UUIDs are in the map,
// but story messages might specify inactive UUIDs, which we // but story messages might specify inactive UUIDs, which we
// have previously filtered // have previously filtered
.filter(r -> uuidToAccountMap.containsKey(r.getUuid())) .filter(r -> accountsByServiceIdentifier.containsKey(r.uuid()))
.collect(Collectors.toMap( .collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()), recipient -> accountsByServiceIdentifier.get(recipient.uuid()),
recipient -> new HashSet<>( recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), Collections.singletonList(new Pair<>(recipient.deviceId(), recipient.registrationId()))),
(a, b) -> { (a, b) -> {
a.addAll(b); a.addAll(b);
return a; return a;
@ -376,33 +373,29 @@ public class MessageController {
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) { @NotNull @Valid MultiRecipientMessage multiRecipientMessage) {
// we skip "missing" accounts when story=true. final Map<ServiceIdentifier, Account> accountsByServiceIdentifier = new HashMap<>();
// otherwise, we return a 404 status code.
final Function<UUID, Stream<Account>> accountFinder = uuid -> {
Optional<Account> res = accountsManager.getByAccountIdentifier(uuid);
if (!isStory && res.isEmpty()) {
throw new WebApplicationException(Status.NOT_FOUND);
}
return res.stream();
};
// build a map from UUID to accounts for (final Recipient recipient : multiRecipientMessage.recipients()) {
Map<UUID, Account> uuidToAccountMap = if (!accountsByServiceIdentifier.containsKey(recipient.uuid())) {
Arrays.stream(multiRecipientMessage.getRecipients()) final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(recipient.uuid());
.map(Recipient::getUuid)
.distinct() if (maybeAccount.isPresent()) {
.flatMap(accountFinder) accountsByServiceIdentifier.put(recipient.uuid(), maybeAccount.get());
.collect(Collectors.toUnmodifiableMap( } else {
Account::getUuid, if (!isStory) {
Function.identity())); throw new NotFoundException();
}
}
}
}
// Stories will be checked by the client; we bypass access checks here for stories. // Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) { if (!isStory) {
checkAccessKeys(accessKeys, uuidToAccountMap); checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
} }
final Map<Account, Set<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap = final Map<Account, Set<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, uuidToAccountMap); buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);
// We might filter out all the recipients of a story (if none have enabled stories). // We might filter out all the recipients of a story (if none have enabled stories).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
@ -412,7 +405,7 @@ public class MessageController {
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
uuidToAccountMap.values().forEach(account -> { accountsByServiceIdentifier.forEach((serviceIdentifier, account) -> {
if (isStory) { if (isStory) {
checkStoryRateLimit(account); checkStoryRateLimit(account);
@ -434,10 +427,10 @@ public class MessageController {
accountToDeviceIdAndRegistrationIdMap.get(account).stream(), accountToDeviceIdAndRegistrationIdMap.get(account).stream(),
false); false);
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), accountMismatchedDevices.add(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))); new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())));
} catch (StaleDevicesException e) { } catch (StaleDevicesException e) {
accountStaleDevices.add(new AccountStaleDevices(account.getUuid(), new StaleDevices(e.getStaleDevices()))); accountStaleDevices.add(new AccountStaleDevices(serviceIdentifier, new StaleDevices(e.getStaleDevices())));
} }
}); });
if (!accountMismatchedDevices.isEmpty()) { if (!accountMismatchedDevices.isEmpty()) {
@ -455,7 +448,7 @@ public class MessageController {
.build(); .build();
} }
List<UUID> uuids404 = Collections.synchronizedList(new ArrayList<>()); List<ServiceIdentifier> uuids404 = Collections.synchronizedList(new ArrayList<>());
try { try {
final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of( final Counter sentMessageCounter = Metrics.counter(SENT_MESSAGE_COUNTER_NAME, Tags.of(
@ -463,18 +456,18 @@ public class MessageController {
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED))); Tag.of(SENDER_TYPE_TAG_NAME, SENDER_TYPE_UNIDENTIFIED)));
multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.getRecipients()) multiRecipientMessageExecutor.invokeAll(Arrays.stream(multiRecipientMessage.recipients())
.map(recipient -> (Callable<Void>) () -> { .map(recipient -> (Callable<Void>) () -> {
Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); Account destinationAccount = accountsByServiceIdentifier.get(recipient.uuid());
// we asserted this must exist in validateCompleteDeviceList // we asserted this must exist in validateCompleteDeviceList
Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
sentMessageCounter.increment(); sentMessageCounter.increment();
try { try {
sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent, sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, isUrgent,
recipient, multiRecipientMessage.getCommonPayload()); recipient, multiRecipientMessage.commonPayload());
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
uuids404.add(destinationAccount.getUuid()); uuids404.add(recipient.uuid());
} }
return null; return null;
}) })
@ -486,7 +479,7 @@ public class MessageController {
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
} }
private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) { private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<Account> destinationAccounts) {
// We should not have null access keys when checking access; bail out early. // We should not have null access keys when checking access; bail out early.
if (accessKeys == null) { if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED); throw new WebApplicationException(Status.UNAUTHORIZED);
@ -494,7 +487,7 @@ public class MessageController {
AtomicBoolean throwUnauthorized = new AtomicBoolean(false); AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
byte[] empty = new byte[16]; byte[] empty = new byte[16];
final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]);
byte[] combinedUnknownAccessKeys = uuidToAccountMap.values().stream() byte[] combinedUnknownAccessKeys = destinationAccounts.stream()
.map(account -> { .map(account -> {
if (account.isUnrestrictedUnidentifiedAccess()) { if (account.isUnrestrictedUnidentifiedAccess()) {
return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY; return UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY;
@ -595,8 +588,8 @@ public class MessageController {
if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) { if (deletedMessage.hasSourceUuid() && deletedMessage.getType() != Type.SERVER_DELIVERY_RECEIPT) {
try { try {
receiptSender.sendReceipt( receiptSender.sendReceipt(
UUID.fromString(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(), ServiceIdentifier.valueOf(deletedMessage.getDestinationUuid()), auth.getAuthenticatedDevice().getId(),
UUID.fromString(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp()); AciServiceIdentifier.valueOf(deletedMessage.getSourceUuid()), deletedMessage.getTimestamp());
} catch (Exception e) { } catch (Exception e) {
logger.warn("Failed to send delivery receipt", e); logger.warn("Failed to send delivery receipt", e);
} }
@ -663,7 +656,7 @@ public class MessageController {
Optional<AuthenticatedAccount> source, Optional<AuthenticatedAccount> source,
Account destinationAccount, Account destinationAccount,
Device destinationDevice, Device destinationDevice,
UUID destinationUuid, ServiceIdentifier destinationIdentifier,
long timestamp, long timestamp,
boolean online, boolean online,
boolean story, boolean story,
@ -679,7 +672,7 @@ public class MessageController {
Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null); Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null);
Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
envelope = incomingMessage.toEnvelope( envelope = incomingMessage.toEnvelope(
destinationUuid, destinationIdentifier,
sourceAccount, sourceAccount,
sourceDeviceId, sourceDeviceId,
timestamp == 0 ? System.currentTimeMillis() : timestamp, timestamp == 0 ? System.currentTimeMillis() : timestamp,
@ -709,10 +702,10 @@ public class MessageController {
try { try {
Envelope.Builder messageBuilder = Envelope.newBuilder(); Envelope.Builder messageBuilder = Envelope.newBuilder();
long serverTimestamp = System.currentTimeMillis(); long serverTimestamp = System.currentTimeMillis();
byte[] recipientKeyMaterial = recipient.getPerRecipientKeyMaterial(); byte[] recipientKeyMaterial = recipient.perRecipientKeyMaterial();
byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length]; byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length];
payload[0] = MultiRecipientMessageProvider.VERSION; payload[0] = MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER;
System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length); System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length);
System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length); System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length);
@ -723,7 +716,7 @@ public class MessageController {
.setContent(ByteString.copyFrom(payload)) .setContent(ByteString.copyFrom(payload))
.setStory(story) .setStory(story)
.setUrgent(urgent) .setUrgent(urgent)
.setDestinationUuid(destinationAccount.getUuid().toString()); .setDestinationUuid(new AciServiceIdentifier(destinationAccount.getUuid()).toServiceIdentifierString());
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
} catch (NotPushRegisteredException e) { } catch (NotPushRegisteredException e) {

View File

@ -90,6 +90,9 @@ import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialPro
import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes; import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes;
import org.whispersystems.textsecuregcm.entities.UserCapabilities; import org.whispersystems.textsecuregcm.entities.UserCapabilities;
import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
@ -234,33 +237,33 @@ public class ProfileController {
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}") @Path("/{identifier}/{version}")
public VersionedProfileResponse getProfile( public VersionedProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth, @Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("uuid") UUID uuid, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version) @PathParam("version") String version)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount);
final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, uuid); final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, accountIdentifier);
return buildVersionedProfileResponse(targetAccount, return buildVersionedProfileResponse(targetAccount,
version, version,
isSelfProfileRequest(maybeRequester, uuid), isSelfProfileRequest(maybeRequester, accountIdentifier),
containerRequestContext); containerRequestContext);
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}/{credentialRequest}") @Path("/{identifier}/{version}/{credentialRequest}")
public CredentialProfileResponse getProfile( public CredentialProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth, @Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("uuid") UUID uuid, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version, @PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest, @PathParam("credentialRequest") String credentialRequest,
@QueryParam("credentialType") String credentialType) @QueryParam("credentialType") String credentialType)
@ -271,8 +274,8 @@ public class ProfileController {
} }
final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount);
final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, uuid); final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, accountIdentifier);
final boolean isSelf = isSelfProfileRequest(maybeRequester, uuid); final boolean isSelf = isSelfProfileRequest(maybeRequester, accountIdentifier);
return buildExpiringProfileKeyCredentialProfileResponse(targetAccount, return buildExpiringProfileKeyCredentialProfileResponse(targetAccount,
version, version,
@ -293,34 +296,38 @@ public class ProfileController {
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") UUID identifier, @PathParam("identifier") ServiceIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate) @QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException { throws RateLimitExceededException {
final Optional<Account> maybeAccountByPni = accountsManager.getByPhoneNumberIdentifier(identifier);
final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount); final Optional<Account> maybeRequester = auth.map(AuthenticatedAccount::getAccount);
final BaseProfileResponse profileResponse; return switch (identifier.identityType()) {
case ACI -> {
final AciServiceIdentifier aciServiceIdentifier = (AciServiceIdentifier) identifier;
if (maybeAccountByPni.isPresent()) { final Account targetAccount =
if (maybeRequester.isEmpty()) { verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, aciServiceIdentifier);
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} else { yield buildBaseProfileResponseForAccountIdentity(targetAccount,
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); isSelfProfileRequest(maybeRequester, aciServiceIdentifier),
containerRequestContext);
} }
case PNI -> {
final Optional<Account> maybeAccountByPni = accountsManager.getByPhoneNumberIdentifier(identifier.uuid());
OptionalAccess.verify(maybeRequester, Optional.empty(), maybeAccountByPni); if (maybeRequester.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} else {
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());
}
profileResponse = buildBaseProfileResponseForPhoneNumberIdentity(maybeAccountByPni.get()); OptionalAccess.verify(maybeRequester, Optional.empty(), maybeAccountByPni);
} else {
final Account targetAccount = verifyPermissionToReceiveAccountIdentityProfile(maybeRequester, accessKey, identifier);
profileResponse = buildBaseProfileResponseForAccountIdentity(targetAccount, assert maybeAccountByPni.isPresent();
isSelfProfileRequest(maybeRequester, identifier), yield buildBaseProfileResponseForPhoneNumberIdentity(maybeAccountByPni.get());
containerRequestContext); }
} };
return profileResponse;
} }
@Timed @Timed
@ -363,35 +370,24 @@ public class ProfileController {
private void checkFingerprintAndAdd(BatchIdentityCheckRequest.Element element, private void checkFingerprintAndAdd(BatchIdentityCheckRequest.Element element,
Collection<BatchIdentityCheckResponse.Element> responseElements, MessageDigest md) { Collection<BatchIdentityCheckResponse.Element> responseElements, MessageDigest md) {
final Optional<Account> maybeAccount; final Optional<Account> maybeAccount = accountsManager.getByServiceIdentifier(element.uuid());
final boolean usePhoneNumberIdentity;
if (element.aci() != null) {
maybeAccount = accountsManager.getByAccountIdentifier(element.aci());
usePhoneNumberIdentity = false;
} else {
final Optional<Account> maybeAciAccount = accountsManager.getByAccountIdentifier(element.uuid());
if (maybeAciAccount.isEmpty()) {
maybeAccount = accountsManager.getByPhoneNumberIdentifier(element.uuid());
usePhoneNumberIdentity = true;
} else {
maybeAccount = maybeAciAccount;
usePhoneNumberIdentity = false;
}
}
maybeAccount.ifPresent(account -> { maybeAccount.ifPresent(account -> {
if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) { if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) {
return; return;
} }
final IdentityKey identityKey =
usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); final IdentityKey identityKey = switch (element.uuid().identityType()) {
case ACI -> account.getIdentityKey();
case PNI -> account.getPhoneNumberIdentityKey();
};
md.reset(); md.reset();
byte[] digest = md.digest(identityKey.serialize()); byte[] digest = md.digest(identityKey.serialize());
byte[] fingerprint = Util.truncate(digest, 4); byte[] fingerprint = Util.truncate(digest, 4);
if (!Arrays.equals(fingerprint, element.fingerprint())) { if (!Arrays.equals(fingerprint, element.fingerprint())) {
responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKey)); responseElements.add(new BatchIdentityCheckResponse.Element(element.uuid(), identityKey));
} }
}); });
} }
@ -454,7 +450,7 @@ public class ProfileController {
getAcceptableLanguagesForRequest(containerRequestContext), getAcceptableLanguagesForRequest(containerRequestContext),
account.getBadges(), account.getBadges(),
isSelf), isSelf),
account.getUuid()); new AciServiceIdentifier(account.getUuid()));
} }
private BaseProfileResponse buildBaseProfileResponseForPhoneNumberIdentity(final Account account) { private BaseProfileResponse buildBaseProfileResponseForPhoneNumberIdentity(final Account account) {
@ -463,7 +459,7 @@ public class ProfileController {
false, false,
UserCapabilities.createForAccount(account), UserCapabilities.createForAccount(account),
Collections.emptyList(), Collections.emptyList(),
account.getPhoneNumberIdentifier()); new PniServiceIdentifier(account.getPhoneNumberIdentifier()));
} }
private ExpiringProfileKeyCredentialResponse getExpiringProfileKeyCredentialResponse( private ExpiringProfileKeyCredentialResponse getExpiringProfileKeyCredentialResponse(
@ -562,7 +558,7 @@ public class ProfileController {
* *
* @param maybeRequester the authenticated account requesting the profile, if any * @param maybeRequester the authenticated account requesting the profile, if any
* @param maybeAccessKey an anonymous access key for the target account * @param maybeAccessKey an anonymous access key for the target account
* @param targetUuid the ACI of the target account * @param accountIdentifier the ACI of the target account
* *
* @return the target account * @return the target account
* *
@ -573,7 +569,7 @@ public class ProfileController {
*/ */
private Account verifyPermissionToReceiveAccountIdentityProfile(final Optional<Account> maybeRequester, private Account verifyPermissionToReceiveAccountIdentityProfile(final Optional<Account> maybeRequester,
final Optional<Anonymous> maybeAccessKey, final Optional<Anonymous> maybeAccessKey,
final UUID targetUuid) throws RateLimitExceededException { final AciServiceIdentifier accountIdentifier) throws RateLimitExceededException {
if (maybeRequester.isEmpty() && maybeAccessKey.isEmpty()) { if (maybeRequester.isEmpty() && maybeAccessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
@ -583,7 +579,7 @@ public class ProfileController {
rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid()); rateLimiters.getProfileLimiter().validate(maybeRequester.get().getUuid());
} }
final Optional<Account> maybeTargetAccount = accountsManager.getByAccountIdentifier(targetUuid); final Optional<Account> maybeTargetAccount = accountsManager.getByAccountIdentifier(accountIdentifier.uuid());
OptionalAccess.verify(maybeRequester, maybeAccessKey, maybeTargetAccount); OptionalAccess.verify(maybeRequester, maybeAccessKey, maybeTargetAccount);
assert maybeTargetAccount.isPresent(); assert maybeTargetAccount.isPresent();
@ -591,7 +587,7 @@ public class ProfileController {
return maybeTargetAccount.get(); return maybeTargetAccount.get();
} }
private boolean isSelfProfileRequest(final Optional<Account> maybeRequester, final UUID targetUuid) { private boolean isSelfProfileRequest(final Optional<Account> maybeRequester, final AciServiceIdentifier targetIdentifier) {
return maybeRequester.map(requester -> requester.getUuid().equals(targetUuid)).orElse(false); return maybeRequester.map(requester -> requester.getUuid().equals(targetIdentifier.uuid())).orElse(false);
} }
} }

View File

@ -4,7 +4,14 @@
*/ */
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.NotNull;
import java.util.UUID;
public record AccountIdentifierResponse(@NotNull UUID uuid) {} import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public record AccountIdentifierResponse(@NotNull
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonDeserialize(using = ServiceIdentifierAdapter.AciServiceIdentifierDeserializer.class)
AciServiceIdentifier uuid) {}

View File

@ -5,22 +5,14 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.UUID; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public class AccountMismatchedDevices { public record AccountMismatchedDevices(@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonProperty @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
public final UUID uuid; ServiceIdentifier uuid,
@JsonProperty MismatchedDevices devices) {
public final MismatchedDevices devices;
public String toString() {
return "AccountMismatchedDevices(" + uuid + ", " + devices + ")";
}
public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) {
this.uuid = uuid;
this.devices = devices;
}
} }

View File

@ -5,22 +5,14 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.UUID; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public class AccountStaleDevices { public record AccountStaleDevices(@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonProperty @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
public final UUID uuid; ServiceIdentifier uuid,
@JsonProperty StaleDevices devices) {
public final StaleDevices devices;
public String toString() {
return "AccountStaleDevices(" + uuid + ", " + devices + ")";
}
public AccountStaleDevices(final UUID uuid, final StaleDevices devices) {
this.uuid = uuid;
this.devices = devices;
}
} }

View File

@ -9,11 +9,11 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import java.util.List; import java.util.List;
import java.util.UUID;
public class BaseProfileResponse { public class BaseProfileResponse {
@ -35,7 +35,9 @@ public class BaseProfileResponse {
private List<Badge> badges; private List<Badge> badges;
@JsonProperty @JsonProperty
private UUID uuid; @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
private ServiceIdentifier uuid;
public BaseProfileResponse() { public BaseProfileResponse() {
} }
@ -45,7 +47,7 @@ public class BaseProfileResponse {
final boolean unrestrictedUnidentifiedAccess, final boolean unrestrictedUnidentifiedAccess,
final UserCapabilities capabilities, final UserCapabilities capabilities,
final List<Badge> badges, final List<Badge> badges,
final UUID uuid) { final ServiceIdentifier uuid) {
this.identityKey = identityKey; this.identityKey = identityKey;
this.unidentifiedAccess = unidentifiedAccess; this.unidentifiedAccess = unidentifiedAccess;
@ -75,7 +77,7 @@ public class BaseProfileResponse {
return badges; return badges;
} }
public UUID getUuid() { public ServiceIdentifier getUuid() {
return uuid; return uuid;
} }
} }

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.List; import java.util.List;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size; import javax.validation.constraints.Size;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ExactlySize; import org.whispersystems.textsecuregcm.util.ExactlySize;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List<Element> elements) { public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List<Element> elements) {
@ -20,18 +22,13 @@ public record BatchIdentityCheckRequest(@Valid @NotNull @Size(max = 1000) List<E
* @param fingerprint most significant 4 bytes of SHA-256 of the 33-byte identity key field (32-byte curve25519 public * @param fingerprint most significant 4 bytes of SHA-256 of the 33-byte identity key field (32-byte curve25519 public
* key prefixed with 0x05) * key prefixed with 0x05)
*/ */
public record Element(@Deprecated @Nullable UUID aci, public record Element(@NotNull
@Nullable UUID uuid, @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@NotNull @ExactlySize(4) byte[] fingerprint) { @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
public Element { @NotNull
if (aci == null && uuid == null) { @ExactlySize(4)
throw new IllegalArgumentException("aci and uuid cannot both be null"); byte[] fingerprint) {
}
if (aci != null && uuid != null) {
throw new IllegalArgumentException("aci and uuid cannot both be non-null");
}
}
} }
} }

View File

@ -6,39 +6,27 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import java.util.List;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public record BatchIdentityCheckResponse(@Valid List<Element> elements) { public record BatchIdentityCheckResponse(@Valid List<Element> elements) {
public record Element(@Deprecated public record Element(@JsonInclude(JsonInclude.Include.NON_EMPTY)
@JsonInclude(JsonInclude.Include.NON_EMPTY) @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@Nullable UUID aci, @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
@NotNull
@JsonInclude(JsonInclude.Include.NON_EMPTY) ServiceIdentifier uuid,
@Nullable UUID uuid,
@NotNull @NotNull
@JsonSerialize(using = IdentityKeyAdapter.Serializer.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
IdentityKey identityKey) { IdentityKey identityKey) {
public Element {
if (aci == null && uuid == null) {
throw new IllegalArgumentException("aci and uuid cannot both be null");
}
if (aci != null && uuid != null) {
throw new IllegalArgumentException("aci and uuid cannot both be non-null");
}
}
} }
} }

View File

@ -6,14 +6,15 @@ package org.whispersystems.textsecuregcm.entities;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.Base64; import java.util.Base64;
import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
public record IncomingMessage(int type, long destinationDeviceId, int destinationRegistrationId, String content) { public record IncomingMessage(int type, long destinationDeviceId, int destinationRegistrationId, String content) {
public MessageProtos.Envelope toEnvelope(final UUID destinationUuid, public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
@Nullable Account sourceAccount, @Nullable Account sourceAccount,
@Nullable Long sourceDeviceId, @Nullable Long sourceDeviceId,
final long timestamp, final long timestamp,
@ -32,13 +33,13 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio
envelopeBuilder.setType(envelopeType) envelopeBuilder.setType(envelopeType)
.setTimestamp(timestamp) .setTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setDestinationUuid(destinationUuid.toString()) .setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setStory(story) .setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) { if (sourceAccount != null && sourceDeviceId != null) {
envelopeBuilder envelopeBuilder
.setSourceUuid(sourceAccount.getUuid().toString()) .setSourceUuid(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue()); .setSourceDevice(sourceDeviceId.intValue());
} }

View File

@ -6,31 +6,15 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List; import java.util.List;
public class MismatchedDevices { public record MismatchedDevices(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
@JsonProperty List<Long> missingDevices,
@Schema(description = "Devices present on the account but absent in the request")
public List<Long> missingDevices;
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
public List<Long> extraDevices;
@VisibleForTesting
public MismatchedDevices() {}
public String toString() {
return "MismatchedDevices(" + missingDevices + ", " + extraDevices + ")";
}
public MismatchedDevices(List<Long> missingDevices, List<Long> extraDevices) {
this.missingDevices = missingDevices;
this.extraDevices = extraDevices;
}
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
List<Long> extraDevices) {
} }

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.entities;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import java.util.Arrays; import java.util.Arrays;
import java.util.UUID; import java.util.Objects;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.Max; import javax.validation.constraints.Max;
@ -16,58 +16,33 @@ import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size; import javax.validation.constraints.Size;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public class MultiRecipientMessage { public record MultiRecipientMessage(
@NotNull @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) @Valid Recipient[] recipients,
@NotNull @Size(min = 32) byte[] commonPayload) {
private static final Counter REJECT_DUPLICATE_RECIPIENT_COUNTER = private static final Counter REJECT_DUPLICATE_RECIPIENT_COUNTER =
Metrics.counter( Metrics.counter(
name(MessageController.class, "rejectDuplicateRecipients"), name(MessageController.class, "rejectDuplicateRecipients"),
"multiRecipient", "false"); "multiRecipient", "false");
public static class Recipient { public record Recipient(@NotNull
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@NotNull @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
private final UUID uuid; ServiceIdentifier uuid,
@Min(1) long deviceId,
@Min(1) @Min(0) @Max(65535) int registrationId,
private final long deviceId; @Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) {
@Min(0)
@Max(65535)
private final int registrationId;
@Size(min = 48, max = 48)
@NotNull
private final byte[] perRecipientKeyMaterial;
public Recipient(UUID uuid, long deviceId, int registrationId, byte[] perRecipientKeyMaterial) {
this.uuid = uuid;
this.deviceId = deviceId;
this.registrationId = registrationId;
this.perRecipientKeyMaterial = perRecipientKeyMaterial;
}
public UUID getUuid() {
return uuid;
}
public long getDeviceId() {
return deviceId;
}
public int getRegistrationId() {
return registrationId;
}
public byte[] getPerRecipientKeyMaterial() {
return perRecipientKeyMaterial;
}
@Override @Override
public boolean equals(final Object o) { public boolean equals(final Object o) {
@ -75,60 +50,48 @@ public class MultiRecipientMessage {
return true; return true;
if (o == null || getClass() != o.getClass()) if (o == null || getClass() != o.getClass())
return false; return false;
Recipient recipient = (Recipient) o; Recipient recipient = (Recipient) o;
return deviceId == recipient.deviceId && registrationId == recipient.registrationId && uuid.equals(recipient.uuid)
if (deviceId != recipient.deviceId) && Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial);
return false;
if (registrationId != recipient.registrationId)
return false;
if (!uuid.equals(recipient.uuid))
return false;
return Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial);
} }
@Override @Override
public int hashCode() { public int hashCode() {
int result = uuid.hashCode(); int result = Objects.hash(uuid, deviceId, registrationId);
result = 31 * result + (int) (deviceId ^ (deviceId >>> 32));
result = 31 * result + registrationId;
result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial); result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial);
return result; return result;
} }
public String toString() {
return "Recipient(" + uuid + ", " + deviceId + ", " + registrationId + ", " + Arrays.toString(perRecipientKeyMaterial) + ")";
}
} }
@NotNull
@Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT)
@Valid
private final Recipient[] recipients;
@NotNull
@Size(min = 32)
private final byte[] commonPayload;
public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) { public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) {
this.recipients = recipients; this.recipients = recipients;
this.commonPayload = commonPayload; this.commonPayload = commonPayload;
} }
public Recipient[] getRecipients() {
return recipients;
}
public byte[] getCommonPayload() {
return commonPayload;
}
@AssertTrue @AssertTrue
public boolean hasNoDuplicateRecipients() { public boolean hasNoDuplicateRecipients() {
boolean valid = Arrays.stream(recipients).map(r -> new Pair<>(r.getUuid(), r.getDeviceId())).distinct().count() == recipients.length; boolean valid =
Arrays.stream(recipients).map(r -> new Pair<>(r.uuid(), r.deviceId())).distinct().count() == recipients.length;
if (!valid) { if (!valid) {
REJECT_DUPLICATE_RECIPIENT_COUNTER.increment(); REJECT_DUPLICATE_RECIPIENT_COUNTER.increment();
} }
return valid; return valid;
} }
@Override
public boolean equals(final Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
MultiRecipientMessage that = (MultiRecipientMessage) o;
return Arrays.equals(recipients, that.recipients) && Arrays.equals(commonPayload, that.commonPayload);
}
@Override
public int hashCode() {
int result = Arrays.hashCode(recipients);
result = 31 * result + Arrays.hashCode(commonPayload);
return result;
}
} }

View File

@ -5,28 +5,50 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, public record OutgoingMessageEntity(UUID guid,
UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, int type,
long serverTimestamp, boolean urgent, boolean story, @Nullable byte[] reportSpamToken) { long timestamp,
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
@Nullable
ServiceIdentifier sourceUuid,
int sourceDevice,
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier destinationUuid,
@Nullable UUID updatedPni,
byte[] content,
long serverTimestamp,
boolean urgent,
boolean story,
@Nullable byte[] reportSpamToken) {
public MessageProtos.Envelope toEnvelope() { public MessageProtos.Envelope toEnvelope() {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type())) .setType(MessageProtos.Envelope.Type.forNumber(type()))
.setTimestamp(timestamp()) .setTimestamp(timestamp())
.setServerTimestamp(serverTimestamp()) .setServerTimestamp(serverTimestamp())
.setDestinationUuid(destinationUuid().toString()) .setDestinationUuid(destinationUuid().toServiceIdentifierString())
.setServerGuid(guid().toString()) .setServerGuid(guid().toString())
.setStory(story) .setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceUuid() != null) { if (sourceUuid() != null) {
builder.setSourceUuid(sourceUuid().toString()); builder.setSourceUuid(sourceUuid().toServiceIdentifierString());
builder.setSourceDevice(sourceDevice()); builder.setSourceDevice(sourceDevice());
} }
@ -51,9 +73,9 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab
UUID.fromString(envelope.getServerGuid()), UUID.fromString(envelope.getServerGuid()),
envelope.getType().getNumber(), envelope.getType().getNumber(),
envelope.getTimestamp(), envelope.getTimestamp(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, envelope.hasSourceUuid() ? ServiceIdentifier.valueOf(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(), envelope.getSourceDevice(),
envelope.hasDestinationUuid() ? UUID.fromString(envelope.getDestinationUuid()) : null, envelope.hasDestinationUuid() ? ServiceIdentifier.valueOf(envelope.getDestinationUuid()) : null,
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.getContent().toByteArray(), envelope.getContent().toByteArray(),
envelope.getServerTimestamp(), envelope.getServerTimestamp(),

View File

@ -6,27 +6,15 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
public class SendMultiRecipientMessageResponse { public record SendMultiRecipientMessageResponse(@JsonSerialize(contentUsing = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonProperty @JsonDeserialize(contentUsing = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
private List<UUID> uuids404; List<ServiceIdentifier> uuids404) {
public SendMultiRecipientMessageResponse() {
}
public String toString() {
return "SendMultiRecipientMessageResponse(" + uuids404 + ")";
}
@VisibleForTesting
public List<UUID> getUUIDs404() {
return this.uuids404;
}
public SendMultiRecipientMessageResponse(final List<UUID> uuids404) {
this.uuids404 = uuids404;
}
} }

View File

@ -10,20 +10,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List; import java.util.List;
public class StaleDevices { public record StaleDevices(@JsonProperty
@Schema(description = "Devices that are no longer active")
@JsonProperty List<Long> staleDevices) {
@Schema(description = "Devices that are no longer active")
private List<Long> staleDevices;
public StaleDevices() {}
public String toString() {
return "StaleDevices(" + staleDevices + ")";
}
public StaleDevices(List<Long> staleDevices) {
this.staleDevices = staleDevices;
}
} }

View File

@ -0,0 +1,75 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.UUID;
import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* An identifier for an account based on the account's ACI.
*
* @param uuid the account's ACI UUID
*/
@Schema(
type = "string",
description = "An identifier for an account based on the account's ACI"
)
public record AciServiceIdentifier(UUID uuid) implements ServiceIdentifier {
private static final IdentityType IDENTITY_TYPE = IdentityType.ACI;
@Override
public IdentityType identityType() {
return IDENTITY_TYPE;
}
@Override
public String toServiceIdentifierString() {
return uuid.toString();
}
@Override
public byte[] toCompactByteArray() {
return UUIDUtil.toBytes(uuid);
}
@Override
public byte[] toFixedWidthByteArray() {
final ByteBuffer byteBuffer = ByteBuffer.allocate(17);
byteBuffer.put(IDENTITY_TYPE.getBytePrefix());
byteBuffer.putLong(uuid.getMostSignificantBits());
byteBuffer.putLong(uuid.getLeastSignificantBits());
byteBuffer.flip();
return byteBuffer.array();
}
public static AciServiceIdentifier valueOf(final String string) {
return new AciServiceIdentifier(
UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix())
? string.substring(IDENTITY_TYPE.getStringPrefix().length()) : string));
}
public static AciServiceIdentifier fromBytes(final byte[] bytes) {
final UUID uuid;
if (bytes.length == 17) {
if (bytes[0] != IDENTITY_TYPE.getBytePrefix()) {
throw new IllegalArgumentException("Unexpected byte array prefix: " + HexFormat.of().formatHex(new byte[] { bytes[0] }));
}
uuid = UUIDUtil.fromBytes(Arrays.copyOfRange(bytes, 1, bytes.length));
} else {
uuid = UUIDUtil.fromBytes(bytes);
}
return new AciServiceIdentifier(uuid);
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
public enum IdentityType {
ACI((byte) 0x00, "ACI:"),
PNI((byte) 0x01, "PNI:");
private final byte bytePrefix;
private final String stringPrefix;
IdentityType(final byte bytePrefix, final String stringPrefix) {
this.bytePrefix = bytePrefix;
this.stringPrefix = stringPrefix;
}
byte getBytePrefix() {
return bytePrefix;
}
String getStringPrefix() {
return stringPrefix;
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.UUID;
/**
* An identifier for an account based on the account's phone number identifier (PNI).
*
* @param uuid the account's PNI UUID
*/
@Schema(
type = "string",
description = "An identifier for an account based on the account's phone number identifier (PNI)"
)
public record PniServiceIdentifier(UUID uuid) implements ServiceIdentifier {
private static final IdentityType IDENTITY_TYPE = IdentityType.PNI;
@Override
public IdentityType identityType() {
return IDENTITY_TYPE;
}
@Override
public String toServiceIdentifierString() {
return IDENTITY_TYPE.getStringPrefix() + uuid.toString();
}
@Override
public byte[] toCompactByteArray() {
return toFixedWidthByteArray();
}
@Override
public byte[] toFixedWidthByteArray() {
final ByteBuffer byteBuffer = ByteBuffer.allocate(17);
byteBuffer.put(IDENTITY_TYPE.getBytePrefix());
byteBuffer.putLong(uuid.getMostSignificantBits());
byteBuffer.putLong(uuid.getLeastSignificantBits());
byteBuffer.flip();
return byteBuffer.array();
}
public static PniServiceIdentifier valueOf(final String string) {
if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) {
throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix");
}
return new PniServiceIdentifier(UUID.fromString(string.substring(IDENTITY_TYPE.getStringPrefix().length())));
}
public static PniServiceIdentifier fromBytes(final byte[] bytes) {
if (bytes.length == 17) {
if (bytes[0] != IDENTITY_TYPE.getBytePrefix()) {
throw new IllegalArgumentException("Unexpected byte array prefix: " + HexFormat.of().formatHex(new byte[] { bytes[0] }));
}
return new PniServiceIdentifier(UUIDUtil.fromBytes(Arrays.copyOfRange(bytes, 1, bytes.length)));
}
throw new IllegalArgumentException("Unexpected byte array length: " + bytes.length);
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.UUID;
/**
* A "service identifier" is a tuple of a UUID and identity type that identifies an account and identity within the
* Signal service.
*/
@Schema(
type = "string",
description = "A service identifier is a tuple of a UUID and identity type that identifies an account and identity within the Signal service.",
subTypes = {AciServiceIdentifier.class, PniServiceIdentifier.class}
)
public interface ServiceIdentifier {
/**
* Returns the identity type of this account identifier.
*
* @return the identity type of this account identifier
*/
IdentityType identityType();
/**
* Returns the UUID for this account identifier.
*
* @return the UUID for this account identifier
*/
UUID uuid();
/**
* Returns a string representation of this account identifier in a format that clients can unambiguously resolve into
* an identity type and UUID.
*
* @return a "strongly-typed" string representation of this account identifier
*/
String toServiceIdentifierString();
/**
* Returns a compact binary representation of this account identifier.
*
* @return a binary representation of this account identifier
*/
byte[] toCompactByteArray();
/**
* Returns a fixed-width binary representation of this account identifier.
*
* @return a binary representation of this account identifier
*/
byte[] toFixedWidthByteArray();
static ServiceIdentifier valueOf(final String string) {
try {
return AciServiceIdentifier.valueOf(string);
} catch (final IllegalArgumentException e) {
return PniServiceIdentifier.valueOf(string);
}
}
static ServiceIdentifier fromBytes(final byte[] bytes) {
try {
return AciServiceIdentifier.fromBytes(bytes);
} catch (final IllegalArgumentException e) {
return PniServiceIdentifier.fromBytes(bytes);
}
}
}

View File

@ -9,22 +9,20 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public final class MessageMetrics { public final class MessageMetrics {
@ -44,16 +42,15 @@ public final class MessageMetrics {
final MessageProtos.Envelope envelope) { final MessageProtos.Envelope envelope) {
if (envelope.hasDestinationUuid()) { if (envelope.hasDestinationUuid()) {
try { try {
final UUID destinationUuid = UUID.fromString(envelope.getDestinationUuid()); measureAccountDestinationUuidMismatches(account, ServiceIdentifier.valueOf(envelope.getDestinationUuid()));
measureAccountDestinationUuidMismatches(account, destinationUuid);
} catch (final IllegalArgumentException ignored) { } catch (final IllegalArgumentException ignored) {
logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid()); logger.warn("Envelope had invalid destination UUID: {}", envelope.getDestinationUuid());
} }
} }
} }
private static void measureAccountDestinationUuidMismatches(final Account account, final UUID destinationUuid) { private static void measureAccountDestinationUuidMismatches(final Account account, final ServiceIdentifier destinationIdentifier) {
if (!destinationUuid.equals(account.getUuid()) && !destinationUuid.equals(account.getPhoneNumberIdentifier())) { if (!account.isIdentifiedBy(destinationIdentifier)) {
// In all cases, this represents a mismatch between the accounts current PNI and its PNI when the message was // In all cases, this represents a mismatch between the accounts current PNI and its PNI when the message was
// sent. This is an expected case, but if this metric changes significantly, it could indicate an issue to // sent. This is an expected case, but if this metric changes significantly, it could indicate an issue to
// investigate. // investigate.

View File

@ -11,7 +11,6 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.util.UUID;
import javax.ws.rs.BadRequestException; import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
@ -21,6 +20,7 @@ import javax.ws.rs.core.NoContentException;
import javax.ws.rs.ext.MessageBodyReader; import javax.ws.rs.ext.MessageBodyReader;
import javax.ws.rs.ext.Provider; import javax.ws.rs.ext.Provider;
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
@Provider @Provider
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
@ -29,7 +29,30 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm"; public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
public static final int MAX_RECIPIENT_COUNT = 5000; public static final int MAX_RECIPIENT_COUNT = 5000;
public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256)); public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256));
public static final byte VERSION = 0x22;
public static final byte AMBIGUOUS_ID_VERSION_IDENTIFIER = 0x22;
public static final byte EXPLICIT_ID_VERSION_IDENTIFIER = 0x23;
private enum Version {
AMBIGUOUS_ID(AMBIGUOUS_ID_VERSION_IDENTIFIER),
EXPLICIT_ID(EXPLICIT_ID_VERSION_IDENTIFIER);
private final byte identifier;
Version(final byte identifier) {
this.identifier = identifier;
}
static Version forVersionByte(final byte versionByte) {
for (final Version version : values()) {
if (version.identifier == versionByte) {
return version;
}
}
throw new IllegalArgumentException("Unrecognized version byte: " + versionByte);
}
}
@Override @Override
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) { public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
@ -44,23 +67,29 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
if (versionByte == -1) { if (versionByte == -1) {
throw new NoContentException("Empty body not allowed"); throw new NoContentException("Empty body not allowed");
} }
if (versionByte != VERSION) {
final Version version;
try {
version = Version.forVersionByte((byte) versionByte);
} catch (final IllegalArgumentException e) {
throw new BadRequestException("Unsupported version"); throw new BadRequestException("Unsupported version");
} }
long count = readVarint(entityStream); long count = readVarint(entityStream);
if (count > MAX_RECIPIENT_COUNT) { if (count > MAX_RECIPIENT_COUNT) {
throw new BadRequestException("Maximum recipient count exceeded"); throw new BadRequestException("Maximum recipient count exceeded");
} }
MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)]; MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)];
for (int i = 0; i < Math.toIntExact(count); i++) { for (int i = 0; i < Math.toIntExact(count); i++) {
UUID uuid = readUuid(entityStream); ServiceIdentifier identifier = readIdentifier(entityStream, version);
long deviceId = readVarint(entityStream); long deviceId = readVarint(entityStream);
int registrationId = readU16(entityStream); int registrationId = readU16(entityStream);
byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
if (perRecipientKeyMaterial.length != 48) { if (perRecipientKeyMaterial.length != 48) {
throw new IOException("Failed to read expected number of key material bytes for a recipient"); throw new IOException("Failed to read expected number of key material bytes for a recipient");
} }
recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, registrationId, perRecipientKeyMaterial); recipients[i] = new MultiRecipientMessage.Recipient(identifier, deviceId, registrationId, perRecipientKeyMaterial);
} }
// caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than
@ -73,32 +102,15 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
} }
/** /**
* Reads a UUID in network byte order and converts to a UUID object. * Reads a service identifier from the given stream.
*/ */
private UUID readUuid(InputStream stream) throws IOException { private ServiceIdentifier readIdentifier(final InputStream stream, final Version version) throws IOException {
byte[] buffer = new byte[8]; final byte[] uuidBytes = switch (version) {
case AMBIGUOUS_ID -> stream.readNBytes(16);
case EXPLICIT_ID -> stream.readNBytes(17);
};
int read = stream.readNBytes(buffer, 0, 8); return ServiceIdentifier.fromBytes(uuidBytes);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long msb = convertNetworkByteOrderToLong(buffer);
read = stream.readNBytes(buffer, 0, 8);
if (read != 8) {
throw new IOException("Insufficient bytes for UUID");
}
long lsb = convertNetworkByteOrderToLong(buffer);
return new UUID(msb, lsb);
}
private long convertNetworkByteOrderToLong(byte[] buffer) {
long result = 0;
for (int i = 0; i < 8; i++) {
result = (result << 8) | (buffer[i] & 0xFFL);
}
return result;
} }
/** /**

View File

@ -9,11 +9,12 @@ import com.codahale.metrics.InstrumentedExecutorService;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics; import io.micrometer.core.instrument.binder.jvm.ExecutorServiceMetrics;
import java.util.UUID;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -40,20 +41,20 @@ public class ReceiptSender {
; ;
} }
public void sendReceipt(UUID sourceUuid, long sourceDeviceId, UUID destinationUuid, long messageId) { public void sendReceipt(ServiceIdentifier sourceIdentifier, long sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) {
if (sourceUuid.equals(destinationUuid)) { if (sourceIdentifier.equals(destinationIdentifier)) {
return; return;
} }
executor.submit(() -> { executor.submit(() -> {
try { try {
accountManager.getByAccountIdentifier(destinationUuid).ifPresentOrElse( accountManager.getByAccountIdentifier(destinationIdentifier.uuid()).ifPresentOrElse(
destinationAccount -> { destinationAccount -> {
final Envelope.Builder message = Envelope.newBuilder() final Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setSourceUuid(sourceUuid.toString()) .setSourceUuid(sourceIdentifier.toServiceIdentifierString())
.setSourceDevice((int) sourceDeviceId) .setSourceDevice((int) sourceDeviceId)
.setDestinationUuid(destinationUuid.toString()) .setDestinationUuid(destinationIdentifier.toServiceIdentifierString())
.setTimestamp(messageId) .setTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT) .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false); .setUrgent(false);
@ -68,7 +69,7 @@ public class ReceiptSender {
} }
} }
}, },
() -> logger.info("No longer registered: {}", destinationUuid) () -> logger.info("No longer registered: {}", destinationIdentifier)
); );
} catch (final Exception e) { } catch (final Exception e) {

View File

@ -25,6 +25,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
@ -123,13 +124,17 @@ public class Account {
} }
/** /**
* Tests whether this account's account identifier or phone number identifier matches the given UUID. * Tests whether this account's account identifier or phone number identifier (depending on the given service
* identifier's identity type) matches the given service identifier.
* *
* @param identifier the identifier to test * @param serviceIdentifier the identifier to test
* @return {@code true} if this account's identifier or phone number identifier matches * @return {@code true} if this account's identifier or phone number identifier matches
*/ */
public boolean isIdentifiedBy(final UUID identifier) { public boolean isIdentifiedBy(final ServiceIdentifier serviceIdentifier) {
return uuid.equals(identifier) || (phoneNumberIdentifier != null && phoneNumberIdentifier.equals(identifier)); return switch (serviceIdentifier.identityType()) {
case ACI -> serviceIdentifier.uuid().equals(uuid);
case PNI -> serviceIdentifier.uuid().equals(phoneNumberIdentifier);
};
} }
public String getNumber() { public String getNumber() {

View File

@ -52,6 +52,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
@ -803,6 +804,13 @@ public class AccountsManager {
); );
} }
public Optional<Account> getByServiceIdentifier(final ServiceIdentifier serviceIdentifier) {
return switch (serviceIdentifier.identityType()) {
case ACI -> getByAccountIdentifier(serviceIdentifier.uuid());
case PNI -> getByPhoneNumberIdentifier(serviceIdentifier.uuid());
};
}
public Optional<Account> getByAccountIdentifier(final UUID uuid) { public Optional<Account> getByAccountIdentifier(final UUID uuid) {
return checkRedisThenAccounts( return checkRedisThenAccounts(
getByUuidTimer, getByUuidTimer,

View File

@ -24,6 +24,7 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
@ -136,9 +137,9 @@ public class ChangeNumberManager {
.setType(Envelope.Type.forNumber(message.type())) .setType(Envelope.Type.forNumber(message.type()))
.setTimestamp(serverTimestamp) .setTimestamp(serverTimestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setDestinationUuid(sourceAndDestinationAccount.getUuid().toString()) .setDestinationUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setContent(ByteString.copyFrom(contents.get())) .setContent(ByteString.copyFrom(contents.get()))
.setSourceUuid(sourceAndDestinationAccount.getUuid().toString()) .setSourceUuid(new AciServiceIdentifier(sourceAndDestinationAccount.getUuid()).toServiceIdentifierString())
.setSourceDevice((int) Device.MASTER_ID) .setSourceDevice((int) Device.MASTER_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString()) .setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true) .setUrgent(true)

View File

@ -0,0 +1,60 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import java.io.IOException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
public class ServiceIdentifierAdapter {
public static class ServiceIdentifierSerializer extends JsonSerializer<ServiceIdentifier> {
@Override
public void serialize(final ServiceIdentifier identifier, final JsonGenerator jsonGenerator, final SerializerProvider serializers)
throws IOException {
jsonGenerator.writeString(identifier.toServiceIdentifierString());
}
}
public static class AciServiceIdentifierDeserializer extends JsonDeserializer<AciServiceIdentifier> {
@Override
public AciServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context)
throws IOException {
return AciServiceIdentifier.valueOf(parser.getValueAsString());
}
}
public static class PniServiceIdentifierDeserializer extends JsonDeserializer<PniServiceIdentifier> {
@Override
public PniServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context)
throws IOException {
return PniServiceIdentifier.valueOf(parser.getValueAsString());
}
}
public static class ServiceIdentifierDeserializer extends JsonDeserializer<ServiceIdentifier> {
@Override
public ServiceIdentifier deserialize(final JsonParser parser, final DeserializationContext context)
throws IOException {
return ServiceIdentifier.valueOf(parser.getValueAsString());
}
}
}

View File

@ -40,6 +40,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@ -265,8 +267,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
} }
try { try {
receiptSender.sendReceipt(UUID.fromString(message.getDestinationUuid()), receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationUuid()),
auth.getAuthenticatedDevice().getId(), UUID.fromString(message.getSourceUuid()), auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceUuid()),
message.getTimestamp()); message.getTimestamp());
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceUuid()); logger.error("Could not parse UUID: {}", message.getSourceUuid());

View File

@ -70,6 +70,8 @@ import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter; import org.whispersystems.textsecuregcm.limits.RateLimitByIpFilter;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -869,10 +871,9 @@ class AccountControllerTest {
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID();
when(accountsManager.getByAccountIdentifier(any())).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(accountIdentifier))).thenReturn(Optional.of(account));
when(accountsManager.getByPhoneNumberIdentifier(any())).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(phoneNumberIdentifier))).thenReturn(Optional.of(account));
when(accountsManager.getByPhoneNumberIdentifier(phoneNumberIdentifier)).thenReturn(Optional.of(account));
when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(mock(RateLimiter.class)); when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(mock(RateLimiter.class));
@ -884,7 +885,7 @@ class AccountControllerTest {
.getStatus()).isEqualTo(200); .getStatus()).isEqualTo(200);
assertThat(resources.getJerseyTest() assertThat(resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", phoneNumberIdentifier)) .target(String.format("/v1/accounts/account/PNI:%s", phoneNumberIdentifier))
.request() .request()
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.head() .head()
@ -954,7 +955,7 @@ class AccountControllerTest {
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.get(); .get();
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.readEntity(AccountIdentifierResponse.class).uuid()).isEqualTo(uuid); assertThat(response.readEntity(AccountIdentifierResponse.class).uuid().uuid()).isEqualTo(uuid);
} }
@Test @Test

View File

@ -57,6 +57,8 @@ import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
@ -77,7 +79,6 @@ class KeysControllerTest {
private static final UUID EXISTS_UUID = UUID.randomUUID(); private static final UUID EXISTS_UUID = UUID.randomUUID();
private static final UUID EXISTS_PNI = UUID.randomUUID(); private static final UUID EXISTS_PNI = UUID.randomUUID();
private static final String NOT_EXISTS_NUMBER = "+14152222220";
private static final UUID NOT_EXISTS_UUID = UUID.randomUUID(); private static final UUID NOT_EXISTS_UUID = UUID.randomUUID();
private static final int SAMPLE_REGISTRATION_ID = 999; private static final int SAMPLE_REGISTRATION_ID = 999;
@ -212,12 +213,10 @@ class KeysControllerTest {
when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER); when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER);
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes())); when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.getByE164(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount)); when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accounts.getByAccountIdentifier(EXISTS_UUID)).thenReturn(Optional.of(existsAccount));
when(accounts.getByPhoneNumberIdentifier(EXISTS_PNI)).thenReturn(Optional.of(existsAccount));
when(accounts.getByE164(NOT_EXISTS_NUMBER)).thenReturn(Optional.empty()); when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByAccountIdentifier(NOT_EXISTS_UUID)).thenReturn(Optional.empty()); when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
@ -384,7 +383,7 @@ class KeysControllerTest {
@Test @Test
void validSingleRequestByPhoneNumberIdentifierTestV2() { void validSingleRequestByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI)) .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
@ -404,7 +403,7 @@ class KeysControllerTest {
@Test @Test
void validSingleRequestPqByPhoneNumberIdentifierTestV2() { void validSingleRequestPqByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI)) .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.queryParam("pq", "true") .queryParam("pq", "true")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
@ -428,7 +427,7 @@ class KeysControllerTest {
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty()); when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI)) .target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
@ -451,7 +450,7 @@ class KeysControllerTest {
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString()); doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString());
Response result = resources.getJerseyTest() Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_PNI)) .target(String.format("/v2/keys/PNI:%s/*", EXISTS_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(); .get();

View File

@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
@ -29,6 +30,7 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@ -42,11 +44,13 @@ import java.nio.ByteOrder;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -69,6 +73,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
@ -78,6 +83,8 @@ import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
@ -91,11 +98,15 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
@ -111,6 +122,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.websocket.Stories; import org.whispersystems.websocket.Stories;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
@ -191,11 +203,11 @@ class MessageControllerTest {
Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByAccountIdentifier(eq(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(MULTI_DEVICE_PNI)).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
final DynamicDeliveryLatencyConfiguration deliveryLatencyConfiguration = mock(DynamicDeliveryLatencyConfiguration.class); final DynamicDeliveryLatencyConfiguration deliveryLatencyConfiguration = mock(DynamicDeliveryLatencyConfiguration.class);
when(deliveryLatencyConfiguration.instrumentedVersions()).thenReturn(Collections.emptyMap()); when(deliveryLatencyConfiguration.instrumentedVersions()).thenReturn(Collections.emptyMap());
@ -310,7 +322,7 @@ class MessageControllerTest {
void testSingleDeviceCurrentByPni() throws Exception { void testSingleDeviceCurrentByPni() throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_PNI)) .target(String.format("/v1/messages/PNI:%s", SINGLE_DEVICE_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
@ -471,7 +483,7 @@ class MessageControllerTest {
void testMultiDeviceByPni() throws Exception { void testMultiDeviceByPni() throws Exception {
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", MULTI_DEVICE_PNI)) .target(String.format("/v1/messages/PNI:%s", MULTI_DEVICE_PNI))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"), .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_multi_device_pni.json"),
@ -543,14 +555,14 @@ class MessageControllerTest {
OutgoingMessageEntity first = messages.get(0); OutgoingMessageEntity first = messages.get(0);
assertEquals(first.timestamp(), timestampOne); assertEquals(first.timestamp(), timestampOne);
assertEquals(first.guid(), messageGuidOne); assertEquals(first.guid(), messageGuidOne);
assertEquals(first.sourceUuid(), sourceUuid); assertEquals(first.sourceUuid().uuid(), sourceUuid);
assertEquals(updatedPniOne, first.updatedPni()); assertEquals(updatedPniOne, first.updatedPni());
if (receiveStories) { if (receiveStories) {
OutgoingMessageEntity second = messages.get(1); OutgoingMessageEntity second = messages.get(1);
assertEquals(second.timestamp(), timestampTwo); assertEquals(second.timestamp(), timestampTwo);
assertEquals(second.guid(), messageGuidTwo); assertEquals(second.guid(), messageGuidTwo);
assertEquals(second.sourceUuid(), sourceUuid); assertEquals(second.sourceUuid().uuid(), sourceUuid);
assertNull(second.updatedPni()); assertNull(second.updatedPni());
} }
@ -623,8 +635,8 @@ class MessageControllerTest {
.delete(); .delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(AuthHelper.VALID_UUID), eq(1L), verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L),
eq(sourceUuid), eq(timestamp)); eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
response = resources.getJerseyTest() response = resources.getJerseyTest()
.target(String.format("/v1/messages/uuid/%s", uuid2)) .target(String.format("/v1/messages/uuid/%s", uuid2))
@ -920,28 +932,32 @@ class MessageControllerTest {
} while (x != 0); } while (x != 0);
} }
private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception { private static void writeMultiPayloadRecipient(final ByteBuffer bb, final Recipient r, final boolean useExplicitIdentifier) {
long msb = r.getUuid().getMostSignificantBits(); if (useExplicitIdentifier) {
long lsb = r.getUuid().getLeastSignificantBits(); bb.put(r.uuid().toFixedWidthByteArray());
bb.putLong(msb); // uuid (first 8 bytes) } else {
bb.putLong(lsb); // uuid (last 8 bytes) bb.put(UUIDUtil.toBytes(r.uuid().uuid()));
writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes) }
bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes)
bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes) writePayloadDeviceId(bb, r.deviceId()); // device id (1-9 bytes)
bb.putShort((short) r.registrationId()); // registration id (2 bytes)
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
} }
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception { private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) {
// initialize a binary payload according to our wire format // initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer); ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN); bb.order(ByteOrder.BIG_ENDIAN);
// first write the header // first write the header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte bb.put(explicitIdentifiers
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
bb.put((byte)recipients.size()); // count varint bb.put((byte)recipients.size()); // count varint
Iterator<Recipient> it = recipients.iterator(); Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) { while (it.hasNext()) {
writeMultiPayloadRecipient(bb, it.next()); writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers);
} }
// now write the actual message body (empty for now) // now write the actual message body (empty for now)
@ -953,22 +969,22 @@ class MessageControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent) throws Exception { void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory, boolean urgent, boolean explicitIdentifier) throws Exception {
final List<Recipient> recipients; final List<Recipient> recipients;
if (recipientUUID == MULTI_DEVICE_UUID) { if (recipientUUID == MULTI_DEVICE_UUID) {
recipients = List.of( recipients = List.of(
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]) new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48])
); );
} else { } else {
recipients = List.of(new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
} }
// initialize our binary payload and create an input stream // initialize our binary payload and create an input stream
byte[] buffer = new byte[2048]; byte[] buffer = new byte[2048];
//InputStream stream = initializeMultiPayload(recipientUUID, buffer); //InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer); InputStream stream = initializeMultiPayload(recipients, buffer, explicitIdentifier);
// set up the entity to use in our PUT request // set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@ -1058,31 +1074,48 @@ class MessageControllerTest {
// Arguments here are: recipient-UUID, is-authorized?, is-story? // Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() { private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of( return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true, true), Arguments.of(MULTI_DEVICE_UUID, false, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, true), Arguments.of(MULTI_DEVICE_UUID, false, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true), Arguments.of(SINGLE_DEVICE_UUID, false, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true), Arguments.of(SINGLE_DEVICE_UUID, false, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, true), Arguments.of(MULTI_DEVICE_UUID, true, true, true, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, true), Arguments.of(MULTI_DEVICE_UUID, true, false, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true), Arguments.of(SINGLE_DEVICE_UUID, true, true, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true), Arguments.of(SINGLE_DEVICE_UUID, true, false, true, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, false), Arguments.of(MULTI_DEVICE_UUID, false, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, false, false), Arguments.of(MULTI_DEVICE_UUID, false, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false), Arguments.of(SINGLE_DEVICE_UUID, false, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false), Arguments.of(SINGLE_DEVICE_UUID, false, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true, false), Arguments.of(MULTI_DEVICE_UUID, true, true, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, false, false), Arguments.of(MULTI_DEVICE_UUID, true, false, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false), Arguments.of(SINGLE_DEVICE_UUID, true, true, false, false),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false) Arguments.of(SINGLE_DEVICE_UUID, true, false, false, false),
Arguments.of(MULTI_DEVICE_UUID, false, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, true, true),
Arguments.of(MULTI_DEVICE_UUID, false, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, true, false, true),
Arguments.of(MULTI_DEVICE_UUID, true, false, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, true, false, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false, false, true)
); );
} }
@Test @ParameterizedTest
void testMultiRecipientRedisBombProtection() throws Exception { @ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of( final List<Recipient> recipients = List.of(
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources Response response = resources
.getJerseyTest() .getJerseyTest()
@ -1094,7 +1127,7 @@ class MessageControllerTest {
.request() .request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048]), MultiRecipientMessageProvider.MEDIA_TYPE)); .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 422); checkBadMultiRecipientResponse(response, 422);
} }
@ -1118,22 +1151,22 @@ class MessageControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known) throws Exception { void testSendMultiRecipientMessageToUnknownAccounts(boolean story, boolean known, boolean useExplicitIdentifier) {
final Recipient r1; final Recipient r1;
if (known) { if (known) {
r1 = new Recipient(SINGLE_DEVICE_UUID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else { } else {
r1 = new Recipient(UUID.randomUUID(), 999, 999, new byte[48]); r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]);
} }
Recipient r2 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(MULTI_DEVICE_UUID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3); List<Recipient> recipients = List.of(r1, r2, r3);
byte[] buffer = new byte[2048]; byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, buffer); InputStream stream = initializeMultiPayload(recipients, buffer, useExplicitIdentifier);
// set up the entity to use in our PUT request // set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
@ -1167,10 +1200,170 @@ class MessageControllerTest {
private static Stream<Arguments> testSendMultiRecipientMessageToUnknownAccounts() { private static Stream<Arguments> testSendMultiRecipientMessageToUnknownAccounts() {
return Stream.of( return Stream.of(
Arguments.of(true, true), Arguments.of(true, true, false),
Arguments.of(true, false), Arguments.of(true, false, false),
Arguments.of(false, true), Arguments.of(false, true, false),
Arguments.of(false, false)); Arguments.of(false, false, false),
Arguments.of(true, true, true),
Arguments.of(true, false, true),
Arguments.of(false, true, true),
Arguments.of(false, false, true)
);
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessageMismatchedDevices(final ServiceIdentifier serviceIdentifier)
throws JsonProcessingException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
assertEquals(409, response.getStatus());
final List<AccountMismatchedDevices> mismatchedDevices =
SystemMapper.jsonMapper().readValue(response.readEntity(String.class),
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class));
assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))),
mismatchedDevices);
}
private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessageStaleDevices(final ServiceIdentifier serviceIdentifier) throws JsonProcessingException {
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1 + 1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2 + 1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", false)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
assertEquals(410, response.getStatus());
final List<AccountStaleDevices> staleDevices =
SystemMapper.jsonMapper().readValue(response.readEntity(String.class),
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountStaleDevices.class));
assertEquals(1, staleDevices.size());
assertEquals(serviceIdentifier, staleDevices.get(0).uuid());
assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices()));
}
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
}
@ParameterizedTest
@MethodSource
void sendMultiRecipientMessage404(final ServiceIdentifier serviceIdentifier)
throws NotPushRegisteredException, InterruptedException {
when(multiRecipientMessageExecutor.invokeAll(any()))
.thenAnswer(answer -> {
final List<Callable> tasks = answer.getArgument(0, List.class);
tasks.forEach(c -> {
try {
c.call();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return null;
});
final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
final Invocation.Builder invocationBuilder = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", false)
.queryParam("ts", System.currentTimeMillis())
.queryParam("story", true)
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
doThrow(NotPushRegisteredException.class)
.when(messageSender).sendMessage(any(), any(), any(), anyBoolean());
// make the PUT request
final SendMultiRecipientMessageResponse response = invocationBuilder.put(entity, SendMultiRecipientMessageResponse.class);
assertEquals(List.of(serviceIdentifier), response.uuids404());
}
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
} }
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
@ -1185,7 +1378,7 @@ class MessageControllerTest {
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture()); verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount); assert (captor.getValue().size() == expectedCount);
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.getUUIDs404().isEmpty()); assert (smrmr.uuids404().isEmpty());
} }
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
@ -1226,7 +1419,7 @@ class MessageControllerTest {
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes); rng.nextBytes(perKeyBytes);
return new Recipient(u1, d1, dr1, perKeyBytes); return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
} }
private static void roundTripVarint(long expected, byte [] bytes) throws Exception { private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
@ -1258,8 +1451,9 @@ class MessageControllerTest {
} }
} }
@Test @ParameterizedTest
void testMultiPayloadRoundtrip() throws Exception { @ValueSource(booleans = {true, false})
void testMultiPayloadRoundtrip(final boolean useExplicitIdentifiers) throws Exception {
Random rng = new java.util.Random(); Random rng = new java.util.Random();
List<Recipient> expected = new LinkedList<>(); List<Recipient> expected = new LinkedList<>();
for(int i = 0; i < 100; i++) { for(int i = 0; i < 100; i++) {
@ -1267,11 +1461,11 @@ class MessageControllerTest {
} }
byte[] buffer = new byte[100 + expected.size() * 100]; byte[] buffer = new byte[100 + expected.size() * 100];
InputStream entityStream = initializeMultiPayload(expected, buffer); InputStream entityStream = initializeMultiPayload(expected, buffer, useExplicitIdentifiers);
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider(); MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
// the provider ignores the headers, java reflection, etc. so we don't use those here. // the provider ignores the headers, java reflection, etc. so we don't use those here.
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream); MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
List<Recipient> got = Arrays.asList(res.getRecipients()); List<Recipient> got = Arrays.asList(res.recipients());
assertEquals(expected, got); assertEquals(expected, got);
} }

View File

@ -38,7 +38,6 @@ import java.util.Collections;
import java.util.HexFormat; import java.util.HexFormat;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -90,6 +89,9 @@ import org.whispersystems.textsecuregcm.entities.CreateProfileRequest;
import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialProfileResponse; import org.whispersystems.textsecuregcm.entities.ExpiringProfileKeyCredentialProfileResponse;
import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes; import org.whispersystems.textsecuregcm.entities.ProfileAvatarUploadAttributes;
import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
@ -202,6 +204,7 @@ class ProfileControllerTest {
Account capabilitiesAccount = mock(Account.class); Account capabilitiesAccount = mock(Account.class);
when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(capabilitiesAccount.getIdentityKey()).thenReturn(ACCOUNT_IDENTITY_KEY); when(capabilitiesAccount.getIdentityKey()).thenReturn(ACCOUNT_IDENTITY_KEY);
when(capabilitiesAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY); when(capabilitiesAccount.getPhoneNumberIdentityKey()).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY);
when(capabilitiesAccount.isEnabled()).thenReturn(true); when(capabilitiesAccount.isEnabled()).thenReturn(true);
@ -209,20 +212,23 @@ class ProfileControllerTest {
when(capabilitiesAccount.isAnnouncementGroupSupported()).thenReturn(true); when(capabilitiesAccount.isAnnouncementGroupSupported()).thenReturn(true);
when(capabilitiesAccount.isChangeNumberSupported()).thenReturn(true); when(capabilitiesAccount.isChangeNumberSupported()).thenReturn(true);
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByPhoneNumberIdentifier(AuthHelper.VALID_PNI_TWO)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByPhoneNumberIdentifier(AuthHelper.VALID_PNI_TWO)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO))).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO))).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByUsernameHash(USERNAME_HASH)).thenReturn(Optional.of(profileAccount)); when(accountsManager.getByUsernameHash(USERNAME_HASH)).thenReturn(Optional.of(profileAccount));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(capabilitiesAccount)); when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount));
when(profilesManager.get(eq(AuthHelper.VALID_UUID), eq("someversion"))).thenReturn(Optional.empty()); when(profilesManager.get(eq(AuthHelper.VALID_UUID), eq("someversion"))).thenReturn(Optional.empty());
when(profilesManager.get(eq(AuthHelper.VALID_UUID_TWO), eq("validversion"))).thenReturn(Optional.of(new VersionedProfile( when(profilesManager.get(eq(AuthHelper.VALID_UUID_TWO), eq("validversion"))).thenReturn(Optional.of(new VersionedProfile(
"validversion", "validname", "profiles/validavatar", "emoji", "about", null, "validcommitmnet".getBytes()))); "validversion", "validname", "profiles/validavatar", "emoji", "about", null, "validcommitmnet".getBytes())));
when(accountsManager.getByAccountIdentifier(AuthHelper.INVALID_UUID)).thenReturn(Optional.empty());
clearInvocations(rateLimiter); clearInvocations(rateLimiter);
clearInvocations(accountsManager); clearInvocations(accountsManager);
clearInvocations(usernameRateLimiter); clearInvocations(usernameRateLimiter);
@ -308,14 +314,14 @@ class ProfileControllerTest {
@Test @Test
void testProfileGetByPni() throws RateLimitExceededException { void testProfileGetByPni() throws RateLimitExceededException {
final BaseProfileResponse profile = resources.getJerseyTest() final BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(BaseProfileResponse.class); .get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
assertThat(profile.getBadges()).isEmpty(); assertThat(profile.getBadges()).isEmpty();
assertThat(profile.getUuid()).isEqualTo(AuthHelper.VALID_PNI_TWO); assertThat(profile.getUuid()).isEqualTo(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO));
assertThat(profile.getCapabilities()).isNotNull(); assertThat(profile.getCapabilities()).isNotNull();
assertThat(profile.isUnrestrictedUnidentifiedAccess()).isFalse(); assertThat(profile.isUnrestrictedUnidentifiedAccess()).isFalse();
assertThat(profile.getUnidentifiedAccess()).isNull(); assertThat(profile.getUnidentifiedAccess()).isNull();
@ -342,7 +348,7 @@ class ProfileControllerTest {
@Test @Test
void testProfileGetByPniUnidentified() throws RateLimitExceededException { void testProfileGetByPniUnidentified() throws RateLimitExceededException {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(); .get();
@ -836,7 +842,7 @@ class ProfileControllerTest {
assertThat(profile.getAboutEmoji()).isEqualTo("emoji"); assertThat(profile.getAboutEmoji()).isEqualTo("emoji");
assertThat(profile.getAvatar()).isEqualTo("profiles/validavatar"); assertThat(profile.getAvatar()).isEqualTo("profiles/validavatar");
assertThat(profile.getBaseProfileResponse().getCapabilities().gv1Migration()).isTrue(); assertThat(profile.getBaseProfileResponse().getCapabilities().gv1Migration()).isTrue();
assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID_TWO); assertThat(profile.getBaseProfileResponse().getUuid()).isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO));
assertThat(profile.getBaseProfileResponse().getBadges()).hasSize(1).element(0).has(new Condition<>( assertThat(profile.getBaseProfileResponse().getBadges()).hasSize(1).element(0).has(new Condition<>(
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name")); badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
@ -927,7 +933,9 @@ class ProfileControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(ExpiringProfileKeyCredentialProfileResponse.class); .get(ExpiringProfileKeyCredentialProfileResponse.class);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid())
.isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID));
assertThat(profile.getCredential()).isNull(); assertThat(profile.getCredential()).isNull();
verify(zkProfileOperations, never()).issueExpiringProfileKeyCredential(any(), any(), any(), any()); verify(zkProfileOperations, never()).issueExpiringProfileKeyCredential(any(), any(), any(), any());
@ -1092,7 +1100,8 @@ class ProfileControllerTest {
.headers(authHeaders) .headers(authHeaders)
.get(ExpiringProfileKeyCredentialProfileResponse.class); .get(ExpiringProfileKeyCredentialProfileResponse.class);
assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(profile.getVersionedProfileResponse().getBaseProfileResponse().getUuid())
.isEqualTo(new AciServiceIdentifier(AuthHelper.VALID_UUID));
assertThat(profile.getCredential()).isEqualTo(credentialResponse); assertThat(profile.getCredential()).isEqualTo(credentialResponse);
verify(zkProfileOperations).issueExpiringProfileKeyCredential(credentialRequest, AuthHelper.VALID_UUID, profileKeyCommitment, expiration); verify(zkProfileOperations).issueExpiringProfileKeyCredential(credentialRequest, AuthHelper.VALID_UUID, profileKeyCommitment, expiration);
@ -1154,13 +1163,13 @@ class ProfileControllerTest {
void testBatchIdentityCheck() { void testBatchIdentityCheck() {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.json(new BatchIdentityCheckRequest(List.of( .post(Entity.json(new BatchIdentityCheckRequest(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID),
convertKeyToFingerprint(ACCOUNT_IDENTITY_KEY)), convertKeyToFingerprint(ACCOUNT_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO),
convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)), convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO),
convertKeyToFingerprint(ACCOUNT_TWO_IDENTITY_KEY)), convertKeyToFingerprint(ACCOUNT_TWO_IDENTITY_KEY)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID),
convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY)) convertKeyToFingerprint(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY))
))))) { ))))) {
assertThat(response).isNotNull(); assertThat(response).isNotNull();
@ -1170,17 +1179,14 @@ class ProfileControllerTest {
assertThat(identityCheckResponse.elements()).isNotNull().isEmpty(); assertThat(identityCheckResponse.elements()).isNotNull().isEmpty();
} }
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> { final Map<ServiceIdentifier, IdentityKey> expectedIdentityKeys = Map.of(
if (AuthHelper.VALID_UUID.equals(element.aci())) { new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY,
return Objects.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY,
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO), ACCOUNT_TWO_IDENTITY_KEY);
return Objects.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey());
} else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) { final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid =
return Objects.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey()); new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())),
} else { "is an expected UUID with the correct identity key");
return false;
}
}, "is an expected UUID with the correct identity key");
final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1189,13 +1195,13 @@ class ProfileControllerTest {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.json(new BatchIdentityCheckRequest(List.of( .post(Entity.json(new BatchIdentityCheckRequest(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID),
convertKeyToFingerprint(validAciIdentityKey)), convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO),
convertKeyToFingerprint(secondValidPniIdentityKey)), convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID_TWO),
convertKeyToFingerprint(secondValidAciIdentityKey)), convertKeyToFingerprint(secondValidAciIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID),
convertKeyToFingerprint(invalidAciIdentityKey)) convertKeyToFingerprint(invalidAciIdentityKey))
))))) { ))))) {
assertThat(response).isNotNull(); assertThat(response).isNotNull();
@ -1209,13 +1215,13 @@ class ProfileControllerTest {
} }
final List<BatchIdentityCheckRequest.Element> largeElementList = new ArrayList<>(List.of( final List<BatchIdentityCheckRequest.Element> largeElementList = new ArrayList<>(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint(validAciIdentityKey)), new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.VALID_UUID), convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint(secondValidPniIdentityKey)), new BatchIdentityCheckRequest.Element(new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint(invalidAciIdentityKey)))); new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(AuthHelper.INVALID_UUID), convertKeyToFingerprint(invalidAciIdentityKey))));
for (int i = 0; i < 900; i++) { for (int i = 0; i < 900; i++) {
largeElementList.add( largeElementList.add(
new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); new BatchIdentityCheckRequest.Element(new AciServiceIdentifier(UUID.randomUUID()), convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
} }
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
@ -1233,27 +1239,25 @@ class ProfileControllerTest {
@Test @Test
void testBatchIdentityCheckDeserialization() throws Exception { void testBatchIdentityCheckDeserialization() throws Exception {
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> { final Map<ServiceIdentifier, IdentityKey> expectedIdentityKeys = Map.of(
if (AuthHelper.VALID_UUID.equals(element.aci())) { new AciServiceIdentifier(AuthHelper.VALID_UUID), ACCOUNT_IDENTITY_KEY,
return ACCOUNT_IDENTITY_KEY.equals(element.identityKey()); new PniServiceIdentifier(AuthHelper.VALID_PNI_TWO), ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) {
return ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY.equals(element.identityKey()); final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid =
} else { new Condition<>(element -> element.identityKey().equals(expectedIdentityKeys.get(element.uuid())),
return false; "is an expected UUID with the correct identity key");
}
}, "is an expected UUID with the correct identity key");
// null properties are ok to omit // null properties are ok to omit
final String json = String.format(""" final String json = String.format("""
{ {
"elements": [ "elements": [
{ "aci": "%s", "fingerprint": "%s" },
{ "uuid": "%s", "fingerprint": "%s" }, { "uuid": "%s", "fingerprint": "%s" },
{ "aci": "%s", "fingerprint": "%s" } { "uuid": "%s", "fingerprint": "%s" },
{ "uuid": "%s", "fingerprint": "%s" }
] ]
} }
""", AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), "PNI:" + AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
@ -1277,50 +1281,34 @@ class ProfileControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testBatchIdentityCheckDeserializationBadRequest(final String json) { void testBatchIdentityCheckDeserializationBadRequest(final String json, final int expectedStatus) {
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.entity(json, "application/json"))) { .post(Entity.entity(json, "application/json"))) {
assertThat(response).isNotNull(); assertThat(response).isNotNull();
assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getStatus()).isEqualTo(expectedStatus);
} }
} }
static Stream<Arguments> testBatchIdentityCheckDeserializationBadRequest() { static Stream<Arguments> testBatchIdentityCheckDeserializationBadRequest() {
return Stream.of( return Stream.of(
Arguments.of( // aci and uuid cannot both be null Arguments.of( // aci and uuid cannot both be null
"""
{
"elements": [
{ "aci": null, "uuid": null, "fingerprint": "%s" }
]
}
"""),
Arguments.of( // an empty string is also invalid
"""
{
"elements": [
{ "aci": "", "uuid": null, "fingerprint": "%s" }
]
}
"""
),
Arguments.of( // as is a blank string
"""
{
"elements": [
{ "aci": null, "uuid": " ", "fingerprint": "%s" }
]
}
"""),
Arguments.of( // aci and uuid cannot both be non-null
String.format(""" String.format("""
{ {
"elements": [ "elements": [
{ "aci": "%s", "uuid": "%s", "fingerprint": "%s" } { "uuid": null, "fingerprint": "%s" }
] ]
} }
""", AuthHelper.VALID_UUID, AuthHelper.VALID_PNI, """, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))),
Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))))) 422),
Arguments.of( // a blank string is invalid
String.format("""
{
"elements": [
{ "uuid": " ", "fingerprint": "%s" }
]
}
""", Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))),
400)
); );
} }

View File

@ -9,20 +9,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Random; import java.util.Random;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junit.jupiter.params.provider.Arguments; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
class OutgoingMessageEntityTest { class OutgoingMessageEntityTest {
@ParameterizedTest @CartesianTest
@MethodSource @CartesianTest.MethodFactory("roundTripThroughEnvelope")
void roundTripThroughEnvelope(@Nullable final UUID sourceUuid, @Nullable final UUID updatedPni) { void roundTripThroughEnvelope(@Nullable final ServiceIdentifier sourceIdentifier,
final ServiceIdentifier destinationIdentifier,
@Nullable final UUID updatedPni) {
final byte[] messageContent = new byte[16]; final byte[] messageContent = new byte[16];
new Random().nextBytes(messageContent); new Random().nextBytes(messageContent);
@ -35,9 +39,9 @@ class OutgoingMessageEntityTest {
UUID.randomUUID(), UUID.randomUUID(),
MessageProtos.Envelope.Type.CIPHERTEXT_VALUE, MessageProtos.Envelope.Type.CIPHERTEXT_VALUE,
messageTimestamp, messageTimestamp,
UUID.randomUUID(), sourceIdentifier,
sourceUuid != null ? (int) Device.MASTER_ID : 0, sourceIdentifier != null ? (int) Device.MASTER_ID : 0,
UUID.randomUUID(), destinationIdentifier,
updatedPni, updatedPni,
messageContent, messageContent,
serverTimestamp, serverTimestamp,
@ -48,11 +52,14 @@ class OutgoingMessageEntityTest {
assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope()));
} }
private static Stream<Arguments> roundTripThroughEnvelope() { @SuppressWarnings("unused")
return Stream.of( static ArgumentSets roundTripThroughEnvelope() {
Arguments.of(UUID.randomUUID(), UUID.randomUUID()), return ArgumentSets.argumentsForFirstParameter(new AciServiceIdentifier(UUID.randomUUID()),
Arguments.of(UUID.randomUUID(), null), new PniServiceIdentifier(UUID.randomUUID()),
Arguments.of(null, UUID.randomUUID())); null)
.argumentsForNextParameter(new AciServiceIdentifier(UUID.randomUUID()),
new PniServiceIdentifier(UUID.randomUUID()))
.argumentsForNextParameter(UUID.randomUUID(), null);
} }
@Test @Test
@ -71,7 +78,7 @@ class OutgoingMessageEntityTest {
IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA"); IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA");
MessageProtos.Envelope baseEnvelope = message.toEnvelope( MessageProtos.Envelope baseEnvelope = message.toEnvelope(
UUID.randomUUID(), new AciServiceIdentifier(UUID.randomUUID()),
account, account,
123L, 123L,
System.currentTimeMillis(), System.currentTimeMillis(),

View File

@ -0,0 +1,77 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.nio.ByteBuffer;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class AciServiceIdentifierTest {
@Test
void identityType() {
assertEquals(IdentityType.ACI, new AciServiceIdentifier(UUID.randomUUID()).identityType());
}
@Test
void toServiceIdentifierString() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid.toString(), new AciServiceIdentifier(uuid).toServiceIdentifierString());
}
@Test
void toCompactByteArray() {
final UUID uuid = UUID.randomUUID();
assertArrayEquals(UUIDUtil.toBytes(uuid), new AciServiceIdentifier(uuid).toCompactByteArray());
}
@Test
void toFixedWidthByteArray() {
final UUID uuid = UUID.randomUUID();
final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17);
expectedBytesBuffer.put((byte) 0x00);
expectedBytesBuffer.putLong(uuid.getMostSignificantBits());
expectedBytesBuffer.putLong(uuid.getLeastSignificantBits());
expectedBytesBuffer.flip();
assertArrayEquals(expectedBytesBuffer.array(), new AciServiceIdentifier(uuid).toFixedWidthByteArray());
}
@Test
void valueOf() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, AciServiceIdentifier.valueOf(uuid.toString()).uuid());
assertEquals(uuid, AciServiceIdentifier.valueOf("ACI:" + uuid).uuid());
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("Not a valid UUID"));
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.valueOf("PNI:" + uuid));
}
@Test
void fromBytes() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, AciServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid)).uuid());
final byte[] prefixedBytes = new byte[17];
prefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16);
assertEquals(uuid, AciServiceIdentifier.fromBytes(prefixedBytes).uuid());
prefixedBytes[0] = 0x01;
assertThrows(IllegalArgumentException.class, () -> AciServiceIdentifier.fromBytes(prefixedBytes));
}
}

View File

@ -0,0 +1,71 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.nio.ByteBuffer;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class PniServiceIdentifierTest {
@Test
void identityType() {
assertEquals(IdentityType.PNI, new PniServiceIdentifier(UUID.randomUUID()).identityType());
}
@Test
void toServiceIdentifierString() {
final UUID uuid = UUID.randomUUID();
assertEquals("PNI:" + uuid, new PniServiceIdentifier(uuid).toServiceIdentifierString());
}
@Test
void toByteArray() {
final UUID uuid = UUID.randomUUID();
final ByteBuffer expectedBytesBuffer = ByteBuffer.allocate(17);
expectedBytesBuffer.put((byte) 0x01);
expectedBytesBuffer.putLong(uuid.getMostSignificantBits());
expectedBytesBuffer.putLong(uuid.getLeastSignificantBits());
expectedBytesBuffer.flip();
assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toCompactByteArray());
assertArrayEquals(expectedBytesBuffer.array(), new PniServiceIdentifier(uuid).toFixedWidthByteArray());
}
@Test
void valueOf() {
final UUID uuid = UUID.randomUUID();
assertEquals(uuid, PniServiceIdentifier.valueOf("PNI:" + uuid).uuid());
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf(uuid.toString()));
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("Not a valid UUID"));
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.valueOf("ACI:" + uuid));
}
@Test
void fromBytes() {
final UUID uuid = UUID.randomUUID();
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(UUIDUtil.toBytes(uuid)));
final byte[] prefixedBytes = new byte[17];
prefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, prefixedBytes, 1, 16);
assertThrows(IllegalArgumentException.class, () -> PniServiceIdentifier.fromBytes(prefixedBytes));
prefixedBytes[0] = 0x01;
assertEquals(uuid, PniServiceIdentifier.fromBytes(prefixedBytes).uuid());
}
}

View File

@ -0,0 +1,87 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.identity;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
class ServiceIdentifierTest {
@ParameterizedTest
@MethodSource
void valueOf(final String identifierString, final IdentityType expectedIdentityType, final UUID expectedUuid) {
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.valueOf(identifierString);
assertEquals(expectedIdentityType, serviceIdentifier.identityType());
assertEquals(expectedUuid, serviceIdentifier.uuid());
}
private static Stream<Arguments> valueOf() {
final UUID uuid = UUID.randomUUID();
return Stream.of(
Arguments.of(uuid.toString(), IdentityType.ACI, uuid),
Arguments.of("ACI:" + uuid, IdentityType.ACI, uuid),
Arguments.of("PNI:" + uuid, IdentityType.PNI, uuid));
}
@ParameterizedTest
@ValueSource(strings = {"Not a valid UUID", "BAD:a9edc243-3e93-45d4-95c6-e3a84cd4a254"})
void valueOfIllegalArgument(final String identifierString) {
assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.valueOf(identifierString));
}
@ParameterizedTest
@MethodSource
void fromBytes(final byte[] bytes, final IdentityType expectedIdentityType, final UUID expectedUuid) {
final ServiceIdentifier serviceIdentifier = ServiceIdentifier.fromBytes(bytes);
assertEquals(expectedIdentityType, serviceIdentifier.identityType());
assertEquals(expectedUuid, serviceIdentifier.uuid());
}
private static Stream<Arguments> fromBytes() {
final UUID uuid = UUID.randomUUID();
final byte[] aciPrefixedBytes = new byte[17];
aciPrefixedBytes[0] = 0x00;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, aciPrefixedBytes, 1, 16);
final byte[] pniPrefixedBytes = new byte[17];
pniPrefixedBytes[0] = 0x01;
System.arraycopy(UUIDUtil.toBytes(uuid), 0, pniPrefixedBytes, 1, 16);
return Stream.of(
Arguments.of(UUIDUtil.toBytes(uuid), IdentityType.ACI, uuid),
Arguments.of(aciPrefixedBytes, IdentityType.ACI, uuid),
Arguments.of(pniPrefixedBytes, IdentityType.PNI, uuid));
}
@ParameterizedTest
@MethodSource
void fromBytesIllegalArgument(final byte[] bytes) {
assertThrows(IllegalArgumentException.class, () -> ServiceIdentifier.fromBytes(bytes));
}
private static Stream<Arguments> fromBytesIllegalArgument() {
final byte[] invalidPrefixBytes = new byte[17];
invalidPrefixBytes[0] = (byte) 0xff;
return Stream.of(
Arguments.of(new byte[0]),
Arguments.of(new byte[15]),
Arguments.of(new byte[18]),
Arguments.of(invalidPrefixBytes));
}
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.metrics;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -21,6 +22,9 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
class MessageMetricsTest { class MessageMetricsTest {
@ -35,6 +39,9 @@ class MessageMetricsTest {
void setup() { void setup() {
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni); when(account.getPhoneNumberIdentifier()).thenReturn(pni);
when(account.isIdentifiedBy(any())).thenReturn(false);
when(account.isIdentifiedBy(new AciServiceIdentifier(aci))).thenReturn(true);
when(account.isIdentifiedBy(new PniServiceIdentifier(pni))).thenReturn(true);
Metrics.globalRegistry.clear(); Metrics.globalRegistry.clear();
simpleMeterRegistry = new SimpleMeterRegistry(); simpleMeterRegistry = new SimpleMeterRegistry();
Metrics.globalRegistry.add(simpleMeterRegistry); Metrics.globalRegistry.add(simpleMeterRegistry);
@ -49,46 +56,46 @@ class MessageMetricsTest {
@Test @Test
void measureAccountOutgoingMessageUuidMismatches() { void measureAccountOutgoingMessageUuidMismatches() {
final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(aci); final OutgoingMessageEntity outgoingMessageToAci = createOutgoingMessageEntity(new AciServiceIdentifier(aci));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToAci); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToAci);
Optional<Counter> counter = findCounter(simpleMeterRegistry); Optional<Counter> counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty()); assertTrue(counter.isEmpty());
final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(pni); final OutgoingMessageEntity outgoingMessageToPni = createOutgoingMessageEntity(new PniServiceIdentifier(pni));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToPni); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToPni);
counter = findCounter(simpleMeterRegistry); counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty()); assertTrue(counter.isEmpty());
final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(otherUuid); final OutgoingMessageEntity outgoingMessageToOtherUuid = createOutgoingMessageEntity(new AciServiceIdentifier(otherUuid));
MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToOtherUuid); MessageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageToOtherUuid);
counter = findCounter(simpleMeterRegistry); counter = findCounter(simpleMeterRegistry);
assertEquals(1.0, counter.map(Counter::count).orElse(0.0)); assertEquals(1.0, counter.map(Counter::count).orElse(0.0));
} }
private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { private OutgoingMessageEntity createOutgoingMessageEntity(final ServiceIdentifier destinationIdentifier) {
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false, null); return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationIdentifier, null, new byte[]{}, 1, true, false, null);
} }
@Test @Test
void measureAccountEnvelopeUuidMismatches() { void measureAccountEnvelopeUuidMismatches() {
final MessageProtos.Envelope envelopeToAci = createEnvelope(aci); final MessageProtos.Envelope envelopeToAci = createEnvelope(new AciServiceIdentifier(aci));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToAci); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToAci);
Optional<Counter> counter = findCounter(simpleMeterRegistry); Optional<Counter> counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty()); assertTrue(counter.isEmpty());
final MessageProtos.Envelope envelopeToPni = createEnvelope(pni); final MessageProtos.Envelope envelopeToPni = createEnvelope(new PniServiceIdentifier(pni));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToPni); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToPni);
counter = findCounter(simpleMeterRegistry); counter = findCounter(simpleMeterRegistry);
assertTrue(counter.isEmpty()); assertTrue(counter.isEmpty());
final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(otherUuid); final MessageProtos.Envelope envelopeToOtherUuid = createEnvelope(new AciServiceIdentifier(otherUuid));
MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToOtherUuid); MessageMetrics.measureAccountEnvelopeUuidMismatches(account, envelopeToOtherUuid);
counter = findCounter(simpleMeterRegistry); counter = findCounter(simpleMeterRegistry);
@ -101,11 +108,11 @@ class MessageMetricsTest {
assertEquals(1.0, counter.map(Counter::count).orElse(0.0)); assertEquals(1.0, counter.map(Counter::count).orElse(0.0));
} }
private MessageProtos.Envelope createEnvelope(UUID destinationUuid) { private MessageProtos.Envelope createEnvelope(ServiceIdentifier destinationIdentifier) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder(); final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder();
if (destinationUuid != null) { if (destinationIdentifier != null) {
builder.setDestinationUuid(destinationUuid.toString()); builder.setDestinationUuid(destinationIdentifier.toServiceIdentifierString());
} }
return builder.build(); return builder.build();

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@ -61,6 +62,8 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@ -208,6 +211,21 @@ class AccountsManagerTest {
mock(Clock.class)); mock(Clock.class));
} }
@Test
void testGetByServiceIdentifier() {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(commands.get(eq("Account3::" + aci))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent());
assertTrue(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(pni)).isPresent());
assertFalse(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(pni)).isPresent());
assertFalse(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(aci)).isPresent());
}
@Test @Test
void testGetAccountByNumberInCache() { void testGetAccountByNumberInCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();

View File

@ -55,6 +55,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicDeliveryLatencyConfiguration;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -225,7 +226,7 @@ class WebSocketConnectionTest {
verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId), verify(messagesManager, times(1)).delete(eq(accountUuid), eq(deviceId),
eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp())); eq(UUID.fromString(outgoingMessages.get(1).getServerGuid())), eq(outgoingMessages.get(1).getServerTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(accountUuid), eq(deviceId), eq(senderOneUuid), verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(accountUuid)), eq(deviceId), eq(new AciServiceIdentifier(senderOneUuid)),
eq(2222L)); eq(2222L));
connection.stop(); connection.stop();
@ -369,7 +370,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response); futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException()); futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account.getUuid()), eq(deviceId), eq(senderTwoUuid), verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getTimestamp())); eq(secondMessage.getTimestamp()));
connection.stop(); connection.stop();