Group Send Credential support in chat

This commit is contained in:
Jonathan Klabunde Tomer 2024-01-04 11:38:57 -08:00 committed by GitHub
parent 195f23c347
commit e1ad25cee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 337 additions and 108 deletions

View File

@ -812,7 +812,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender, new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, pushNotificationManager, reportMessageManager, accountsManager, messagesManager, pushNotificationManager, reportMessageManager,
multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager, multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager,
dynamicConfigurationManager), dynamicConfigurationManager, zkSecretParams),
new PaymentsController(currencyManager, paymentsCredentialsGenerator), new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager, new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,

View File

@ -0,0 +1,26 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Base64;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response.Status;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
public record GroupSendCredentialHeader(GroupSendCredentialPresentation presentation) {
public static GroupSendCredentialHeader valueOf(String header) {
try {
return new GroupSendCredentialHeader(new GroupSendCredentialPresentation(Base64.getDecoder().decode(header)));
} catch (InvalidInputException | IllegalArgumentException e) {
// Base64 throws IllegalArgumentException; GroupSendCredentialPresentation ctor throws InvalidInputException
throw new WebApplicationException(e, Status.UNAUTHORIZED);
}
}
}

View File

@ -18,8 +18,6 @@ import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class OptionalAccess { public class OptionalAccess {
public static final String UNIDENTIFIED = "Unidentified-Access-Key";
public static void verify(Optional<Account> requestAccount, public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey, Optional<Anonymous> accessKey,
Optional<Account> targetAccount, Optional<Account> targetAccount,

View File

@ -61,6 +61,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -206,7 +207,7 @@ public class KeysController {
name = "Retry-After", name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed")) description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth, public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for") @Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") ServiceIdentifier targetIdentifier, @PathParam("identifier") ServiceIdentifier targetIdentifier,

View File

@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
@ -52,6 +54,7 @@ import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue; import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.NotFoundException; import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST; import javax.ws.rs.POST;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
@ -67,13 +70,17 @@ import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.util.Pair; import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.GroupSendCredentialHeader;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -113,6 +120,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories; import org.whispersystems.websocket.Stories;
@ -148,6 +156,7 @@ public class MessageController {
private final ReportSpamTokenProvider reportSpamTokenProvider; private final ReportSpamTokenProvider reportSpamTokenProvider;
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ServerSecretParams serverSecretParams;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8; private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
@ -188,7 +197,8 @@ public class MessageController {
Scheduler messageDeliveryScheduler, Scheduler messageDeliveryScheduler,
@Nonnull ReportSpamTokenProvider reportSpamTokenProvider, @Nonnull ReportSpamTokenProvider reportSpamTokenProvider,
final ClientReleaseManager clientReleaseManager, final ClientReleaseManager clientReleaseManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) { final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ServerSecretParams serverSecretParams) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.messageByteLimitEstimator = messageByteLimitEstimator; this.messageByteLimitEstimator = messageByteLimitEstimator;
this.messageSender = messageSender; this.messageSender = messageSender;
@ -202,6 +212,7 @@ public class MessageController {
this.reportSpamTokenProvider = reportSpamTokenProvider; this.reportSpamTokenProvider = reportSpamTokenProvider;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.serverSecretParams = serverSecretParams;
} }
@Timed @Timed
@ -211,7 +222,7 @@ public class MessageController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@FilterSpam @FilterSpam
public Response sendMessage(@Auth Optional<AuthenticatedAccount> source, public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("destination") ServiceIdentifier destinationIdentifier, @PathParam("destination") ServiceIdentifier destinationIdentifier,
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@ -372,6 +383,7 @@ public class MessageController {
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap( private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) { SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.switchIfEmpty(Flux.error(BadRequestException::new))
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue())) .map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
.flatMap( .flatMap(
t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1())) t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1()))
@ -406,7 +418,9 @@ public class MessageController {
""") """)
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true) @ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times") @ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect") @ApiResponse(
responseCode="401",
description="The message is not a story and the unauthorized access key or group send credential is missing or incorrect")
@ApiResponse( @ApiResponse(
responseCode="404", responseCode="404",
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users") description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
@ -416,10 +430,14 @@ public class MessageController {
@ApiResponse( @ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class))) content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))
public Response sendMultiRecipientMessage( public Response sendMultiRecipientMessage(
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message") @Deprecated
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, @Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message. Will be replaced with group send credentials")
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@Parameter(description="A group send credential covering all (included and excluded) recipients of the message. Must not be combined with `Unidentified-Access-Key` or set on a story message.")
@HeaderParam(HeaderUtils.GROUP_SEND_CREDENTIAL)
@Nullable GroupSendCredentialHeader groupSendCredential,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@ -436,14 +454,31 @@ public class MessageController {
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal") @Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal")
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { @NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
if (groupSendCredential == null && accessKeys == null && !isStory) {
throw new NotAuthorizedException("A group send credential or unidentified access key is required for non-story messages");
}
if (groupSendCredential != null) {
if (accessKeys != null) {
throw new BadRequestException("Only one of group send credential and unidentified access key may be provided");
} else if (isStory) {
throw new BadRequestExcpetion("Stories should not provide a group send credential");
}
}
if (groupSendCredential != null) {
// Group send credentials are checked before we even attempt to resolve any accounts, since
// the lists of service IDs in the envelope are all that we need to check against
checkGroupSendCredential(
multiRecipientMessage.getRecipients().keySet(), multiRecipientMessage.getExcludedRecipients(), groupSendCredential);
}
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory); final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
// Stories will be checked by the client; we bypass access checks here for stories. // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
if (!isStory) { // Group send credentials are checked earlier; for stories, we don't check permissions at all because only clients check them
if (groupSendCredential == null && !isStory) {
checkAccessKeys(accessKeys, recipients.values()); checkAccessKeys(accessKeys, recipients.values());
} }
// We might filter out all the recipients of a story (if none exist). // We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now. // In this case there is no error so we should just return 200 now.
if (isStory) { if (isStory) {
@ -556,12 +591,28 @@ public class MessageController {
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
} }
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MultiRecipientDeliveryData> destinations) { private void checkGroupSendCredential(
// We should not have null access keys when checking access; bail out early. final Collection<ServiceId> recipients,
if (accessKeys == null) { final Collection<ServiceId> excludedRecipients,
throw new WebApplicationException(Status.UNAUTHORIZED); final @NotNull GroupSendCredentialHeader groupSendCredential) {
try {
// A group send credential covers *every* group member except the sender. However, clients
// don't always want to actually send to every recipient in the same multi-send (most
// commonly because a new member needs an SKDM first, but also could be because the sender
// has blocked someone). So we check the group send credential against the combination of
// the actual recipients and the supplied list of "excluded" recipients, accounts the
// sender knows are part of the credential but doesn't want to send to right now.
groupSendCredential.presentation().verify(
Lists.newArrayList(Iterables.concat(recipients, excludedRecipients)),
serverSecretParams);
} catch (VerificationFailedException e) {
throw new NotAuthorizedException(e);
} }
}
private void checkAccessKeys(
final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys,
final Collection<MultiRecipientDeliveryData> destinations) {
final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH; final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH;
final byte[] combinedUnidentifiedAccessKeys = destinations.stream() final byte[] combinedUnidentifiedAccessKeys = destinations.stream()
.map(MultiRecipientDeliveryData::account) .map(MultiRecipientDeliveryData::account)

View File

@ -94,6 +94,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ProfileHelper; import org.whispersystems.textsecuregcm.util.ProfileHelper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -226,7 +227,7 @@ public class ProfileController {
@Path("/{identifier}/{version}") @Path("/{identifier}/{version}")
public VersionedProfileResponse getProfile( public VersionedProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth, @Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version) @PathParam("version") String version)
@ -246,7 +247,7 @@ public class ProfileController {
@Path("/{identifier}/{version}/{credentialRequest}") @Path("/{identifier}/{version}/{credentialRequest}")
public CredentialProfileResponse getProfile( public CredentialProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth, @Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version, @PathParam("version") String version,
@ -276,7 +277,7 @@ public class ProfileController {
@Path("/{identifier}") @Path("/{identifier}")
public BaseProfileResponse getUnversionedProfile( public BaseProfileResponse getUnversionedProfile(
@Auth Optional<AuthenticatedAccount> auth, @Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext, @Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") ServiceIdentifier identifier, @PathParam("identifier") ServiceIdentifier identifier,

View File

@ -10,6 +10,8 @@ import java.util.Arrays;
import java.util.HexFormat; import java.util.HexFormat;
import java.util.UUID; import java.util.UUID;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
/** /**
@ -51,6 +53,11 @@ public record AciServiceIdentifier(UUID uuid) implements ServiceIdentifier {
return byteBuffer.array(); return byteBuffer.array();
} }
@Override
public ServiceId.Aci toLibsignal() {
return new ServiceId.Aci(uuid);
}
public static AciServiceIdentifier valueOf(final String string) { public static AciServiceIdentifier valueOf(final String string) {
return new AciServiceIdentifier( return new AciServiceIdentifier(
UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix()) UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix())

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.identity; package org.whispersystems.textsecuregcm.identity;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Arrays; import java.util.Arrays;
@ -51,6 +53,11 @@ public record PniServiceIdentifier(UUID uuid) implements ServiceIdentifier {
return byteBuffer.array(); return byteBuffer.array();
} }
@Override
public ServiceId.Pni toLibsignal() {
return new ServiceId.Pni(uuid);
}
public static PniServiceIdentifier valueOf(final String string) { public static PniServiceIdentifier valueOf(final String string) {
if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) { if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) {
throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix"); throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix");

View File

@ -81,4 +81,6 @@ public interface ServiceIdentifier {
} }
throw new IllegalArgumentException("unknown libsignal ServiceId type"); throw new IllegalArgumentException("unknown libsignal ServiceId type");
} }
ServiceId toLibsignal();
} }

View File

@ -24,6 +24,10 @@ public final class HeaderUtils {
public static final String TIMESTAMP_HEADER = "X-Signal-Timestamp"; public static final String TIMESTAMP_HEADER = "X-Signal-Timestamp";
public static final String UNIDENTIFIED_ACCESS_KEY = "Unidentified-Access-Key";
public static final String GROUP_SEND_CREDENTIAL = "Group-Send-Credential";
private HeaderUtils() { private HeaderUtils() {
// utility class // utility class
} }

View File

@ -40,12 +40,12 @@ import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations;
import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse; import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -198,7 +198,7 @@ class CertificateControllerTest {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery") .target("/v1/certificate/delivery")
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.get(); .get();
assertEquals(response.getStatus(), 401); assertEquals(response.getStatus(), 401);

View File

@ -48,7 +48,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest { class KeysControllerTest {
@ -494,7 +494,7 @@ class KeysControllerTest {
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true") .queryParam("pq", "true")
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
@ -518,7 +518,7 @@ class KeysControllerTest {
Response result = resources.getJerseyTest() Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(); .get();
assertThat(result).isNotNull(); assertThat(result).isNotNull();
@ -530,7 +530,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes()))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -542,7 +542,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, "$$$$$$$$$") .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, "$$$$$$$$$")
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);

View File

@ -81,10 +81,20 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets; import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.util.Hex;
import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.signal.libsignal.zkgroup.groups.GroupSecretParams;
import org.signal.libsignal.zkgroup.groups.UuidCiphertext;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse;
import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
@ -125,6 +135,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil; import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -137,15 +148,19 @@ import reactor.core.scheduler.Schedulers;
class MessageControllerTest { class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); private static final ServiceIdentifier SINGLE_DEVICE_ACI_ID = new AciServiceIdentifier(SINGLE_DEVICE_UUID);
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier SINGLE_DEVICE_PNI_ID = new PniServiceIdentifier(SINGLE_DEVICE_PNI);
private static final byte SINGLE_DEVICE_ID1 = 1; private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111; private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_ACI_ID = new AciServiceIdentifier(MULTI_DEVICE_UUID);
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_PNI_ID = new PniServiceIdentifier(MULTI_DEVICE_PNI);
private static final byte MULTI_DEVICE_ID1 = 1; private static final byte MULTI_DEVICE_ID1 = 1;
private static final byte MULTI_DEVICE_ID2 = 2; private static final byte MULTI_DEVICE_ID2 = 2;
private static final byte MULTI_DEVICE_ID3 = 3; private static final byte MULTI_DEVICE_ID3 = 3;
@ -157,6 +172,8 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444; private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.randomUUID(); private static final UUID NONEXISTENT_UUID = UUID.randomUUID();
private static final ServiceIdentifier NONEXISTENT_ACI_ID = new AciServiceIdentifier(NONEXISTENT_UUID);
private static final ServiceIdentifier NONEXISTENT_PNI_ID = new PniServiceIdentifier(NONEXISTENT_UUID);
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@ -178,6 +195,7 @@ class MessageControllerTest {
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@ -189,7 +207,8 @@ class MessageControllerTest {
.addResource( .addResource(
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager)) messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams))
.build(); .build();
@BeforeEach @BeforeEach
@ -213,19 +232,19 @@ class MessageControllerTest {
Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_ACI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_PNI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_ACI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_PNI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(NONEXISTENT_ACI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(NONEXISTENT_PNI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount))); when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
@ -381,7 +400,7 @@ class MessageControllerTest {
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
@ -904,7 +923,7 @@ class MessageControllerTest {
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList( .put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true, List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true,
System.currentTimeMillis()), System.currentTimeMillis()),
@ -965,7 +984,21 @@ class MessageControllerTest {
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
} }
private static void writeMultiPayloadExcludedRecipient(final ByteBuffer bb, final ServiceIdentifier id, final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(id.toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(id.uuid()));
}
bb.put((byte) 0);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) { private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) {
return initializeMultiPayload(recipients, List.of(), buffer, explicitIdentifiers);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, List<ServiceIdentifier> excludedRecipients, byte[] buffer, final boolean explicitIdentifiers) {
// initialize a binary payload according to our wire format // initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer); ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN); bb.order(ByteOrder.BIG_ENDIAN);
@ -974,17 +1007,15 @@ class MessageControllerTest {
bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte
// count varint // count varint
int nRecip = recipients.size(); int nRecip = recipients.size() + excludedRecipients.size();
while (nRecip > 127) { while (nRecip > 127) {
bb.put((byte) (nRecip & 0x7F | 0x80)); bb.put((byte) (nRecip & 0x7F | 0x80));
nRecip = nRecip >> 7; nRecip = nRecip >> 7;
} }
bb.put((byte)(nRecip & 0x7F)); bb.put((byte)(nRecip & 0x7F));
Iterator<Recipient> it = recipients.iterator(); recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers));
while (it.hasNext()) { excludedRecipients.forEach(recipient -> writeMultiPayloadExcludedRecipient(bb, recipient, explicitIdentifiers));
writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers);
}
// now write the actual message body (empty for now) // now write the actual message body (empty for now)
bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here) bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here)
@ -1062,8 +1093,17 @@ class MessageControllerTest {
// set up the entity to use in our PUT request // set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request // build correct or incorrect access header
Invocation.Builder bldr = resources final String accessHeader;
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
accessHeader = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
} else {
accessHeader = "BBBBBBBBBBBBBBBBBBBBBB==";
}
// make the PUT request
Response response = resources
.getJerseyTest() .getJerseyTest()
.target("/v1/messages/multi_recipient") .target("/v1/messages/multi_recipient")
.queryParam("online", true) .queryParam("online", true)
@ -1071,17 +1111,9 @@ class MessageControllerTest {
.queryParam("story", isStory) .queryParam("story", isStory)
.queryParam("urgent", urgent) .queryParam("urgent", urgent)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME"); .header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
// add access header if needed .put(entity);
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender, verify(messageSender,
@ -1105,18 +1137,18 @@ class MessageControllerTest {
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() { private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
return return
Map.of( Map.of(
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), SINGLE_DEVICE_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), SINGLE_DEVICE_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ACI_ID,
Map.of( Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_ID,
Map.of( Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2), MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), NONEXISTENT_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) NONEXISTENT_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
); );
} }
@ -1137,16 +1169,16 @@ class MessageControllerTest {
@SuppressWarnings("unused") @SuppressWarnings("unused")
private static ArgumentSets testMultiRecipientMessageNoPni() { private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap(); final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID)); final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, SINGLE_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID)); final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci = final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID)); submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci = final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap( submap(
targets, targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ACI_ID,
new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ACI_ID,
new AciServiceIdentifier(NONEXISTENT_UUID)); NONEXISTENT_ACI_ID);
final boolean auth = true; final boolean auth = true;
final boolean unauth = false; final boolean unauth = false;
@ -1186,18 +1218,18 @@ class MessageControllerTest {
private static ArgumentSets testMultiRecipientMessagePni() { private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap(); final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI)); final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap( final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI)); targets, SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI)); final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed = final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI)); submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed = final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap( submap(
targets, targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI), SINGLE_DEVICE_PNI_ID,
new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ACI_ID,
new PniServiceIdentifier(NONEXISTENT_UUID)); NONEXISTENT_PNI_ID);
final boolean auth = true; final boolean auth = true;
final boolean unauth = false; final boolean unauth = false;
@ -1232,13 +1264,113 @@ class MessageControllerTest {
.argumentsForNextParameter(false, true); // urgent .argumentsForNextParameter(false, true); // urgent
} }
@ParameterizedTest
@MethodSource
void testMultiRecipientMessageWithGroupSendCredential(
List<ServiceIdentifier> includedRecipients,
List<ServiceIdentifier> excludedRecipients,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
includedRecipients.forEach(
serviceIdentifier -> multiRecipientTargetMap().get(serviceIdentifier).forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, excludedRecipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_CREDENTIAL, validGroupSendCredentialHeader(
senderId,
List.of(senderId, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID)))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
exactly(expectedMessagesSent))
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
}
private static Stream<Arguments> testMultiRecipientMessageWithGroupSendCredential() {
return Stream.of(
// All members present in included or excluded recipients: success, deliver to included recipients only
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(), 200, 3),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 200, 1),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 200, 2),
// No included recipients: request is bad
Arguments.of(List.of(), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 400, 0),
// Some recipients both included and excluded: request is bad
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 400, 0),
// Included recipient not covered by credential: forbid
Arguments.of(List.of(NONEXISTENT_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(), 401, 0),
// Excluded recipient not covered by credential: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ACI_ID), 401, 0),
// Some recipients not in included or excluded list: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(), 401, 0),
// Substituting a PNI for an ACI is not allowed
Arguments.of(List.of(SINGLE_DEVICE_PNI_ID, MULTI_DEVICE_ACI_ID), List.of(), 401, 0));
}
private String validGroupSendCredentialHeader(AciServiceIdentifier sender, List<ServiceIdentifier> allGroupMembers) throws Exception {
final ServerPublicParams serverPublicParams = serverSecretParams.getPublicParams();
final GroupMasterKey groupMasterKey = new GroupMasterKey(new byte[32]);
final GroupSecretParams groupSecretParams = GroupSecretParams.deriveFromMasterKey(groupMasterKey);
final ClientZkGroupCipher clientZkGroupCipher = new ClientZkGroupCipher(groupSecretParams);
UuidCiphertext senderCiphertext = clientZkGroupCipher.encrypt(sender.toLibsignal());
List<UuidCiphertext> groupCiphertexts = allGroupMembers.stream()
.map(ServiceIdentifier::toLibsignal)
.map(clientZkGroupCipher::encrypt)
.collect(Collectors.toList());
GroupSendCredentialResponse credentialResponse =
GroupSendCredentialResponse.issueCredential(groupCiphertexts, senderCiphertext, serverSecretParams);
GroupSendCredential credential =
credentialResponse.receive(
allGroupMembers.stream().map(ServiceIdentifier::toLibsignal).collect(Collectors.toList()),
sender.toLibsignal(),
serverPublicParams,
groupSecretParams);
GroupSendCredentialPresentation presentation = credential.present(serverPublicParams);
return Base64.getEncoder().encodeToString(presentation.serialize());
}
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of( final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources Response response = resources
.getJerseyTest() .getJerseyTest()
@ -1249,7 +1381,7 @@ class MessageControllerTest {
.queryParam("urgent", false) .queryParam("urgent", false)
.request() .request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE)); .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 400); checkBadMultiRecipientResponse(response, 400);
@ -1266,7 +1398,7 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", unknownUUID)) .target(String.format("/v1/messages/%s", unknownUUID))
.queryParam("story", "true") .queryParam("story", "true")
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, accessBytes) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes)
.put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE));
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200))); assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
@ -1278,13 +1410,13 @@ class MessageControllerTest {
final Recipient r1; final Recipient r1;
if (known) { if (known) {
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else { } else {
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]); r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]);
} }
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); Recipient r2 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); Recipient r3 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3); List<Recipient> recipients = List.of(r1, r2, r3);
@ -1307,7 +1439,7 @@ class MessageControllerTest {
.queryParam("story", story) .queryParam("story", story)
.request() .request()
.header(HttpHeaders.USER_AGENT, "Test User Agent") .header(HttpHeaders.USER_AGENT, "Test User Agent")
.header(OptionalAccess.UNIDENTIFIED, accessBytes); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes);
// make the PUT request // make the PUT request
Response response = bldr.put(entity); Response response = bldr.put(entity);
@ -1363,7 +1495,7 @@ class MessageControllerTest {
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request // make the PUT request
final Response response = invocationBuilder.put(entity); final Response response = invocationBuilder.put(entity);
@ -1381,8 +1513,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() { private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() {
return Stream.of( return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); Arguments.of(MULTI_DEVICE_PNI_ID));
} }
@ParameterizedTest @ParameterizedTest
@ -1410,7 +1542,7 @@ class MessageControllerTest {
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request // make the PUT request
final Response response = invocationBuilder.put(entity); final Response response = invocationBuilder.put(entity);
@ -1429,8 +1561,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() { private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
return Stream.of( return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); Arguments.of(MULTI_DEVICE_PNI_ID));
} }
@ParameterizedTest @ParameterizedTest
@ -1460,7 +1592,7 @@ class MessageControllerTest {
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
doThrow(NotPushRegisteredException.class) doThrow(NotPushRegisteredException.class)
.when(messageSender).sendMessage(any(), any(), any(), anyBoolean()); .when(messageSender).sendMessage(any(), any(), any(), anyBoolean());
@ -1473,13 +1605,13 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() { private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of( return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), Arguments.of(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); Arguments.of(MULTI_DEVICE_PNI_ID, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
} }
@Test @Test
void sendMultiRecipientMessageStoryRateLimited() { void sendMultiRecipientMessageStoryRateLimited() {
final List<Recipient> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); final List<Recipient> recipients = List.of(new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
// initialize our binary payload and create an input stream // initialize our binary payload and create an input stream
byte[] buffer = new byte[2048]; byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer); // InputStream stream = initializeMultiPayload(recipientUUID, buffer);
@ -1498,7 +1630,7 @@ class MessageControllerTest {
.queryParam("urgent", true) .queryParam("urgent", true)
.request() .request()
.header(HttpHeaders.USER_AGENT, "FIXME") .header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class))) when(rateLimiter.validateAsync(any(UUID.class)))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true)));

View File

@ -75,7 +75,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext; import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations; import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -107,6 +106,7 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -279,7 +279,7 @@ class ProfileControllerTest {
final BaseProfileResponse profile = resources.getJerseyTest() final BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get(BaseProfileResponse.class); .get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY); assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY);
@ -295,7 +295,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -306,7 +306,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("/v1/profile/" + UUID.randomUUID()) .target("/v1/profile/" + UUID.randomUUID())
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -351,7 +351,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO) .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -365,7 +365,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -1140,7 +1140,7 @@ class ProfileControllerTest {
private static Stream<Arguments> testGetProfileWithExpiringProfileKeyCredential() { private static Stream<Arguments> testGetProfileWithExpiringProfileKeyCredential() {
return Stream.of( return Stream.of(
Arguments.of(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))), Arguments.of(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)))), Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO)))) Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))))
); );
@ -1184,7 +1184,7 @@ class ProfileControllerTest {
HexFormat.of().formatHex(credentialRequest.serialize()))) HexFormat.of().formatHex(credentialRequest.serialize())))
.queryParam("credentialType", "expiringProfileKey") .queryParam("credentialType", "expiringProfileKey")
.request() .request()
.headers(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))) .headers(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY))))
.get(); .get();
assertEquals(400, response.getStatus()); assertEquals(400, response.getStatus());