Group Send Credential support in chat

This commit is contained in:
Jonathan Klabunde Tomer 2024-01-04 11:38:57 -08:00 committed by GitHub
parent 195f23c347
commit e1ad25cee0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 337 additions and 108 deletions

View File

@ -812,7 +812,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, pushNotificationManager, reportMessageManager,
multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager,
dynamicConfigurationManager),
dynamicConfigurationManager, zkSecretParams),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,

View File

@ -0,0 +1,26 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Base64;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response.Status;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
public record GroupSendCredentialHeader(GroupSendCredentialPresentation presentation) {
public static GroupSendCredentialHeader valueOf(String header) {
try {
return new GroupSendCredentialHeader(new GroupSendCredentialPresentation(Base64.getDecoder().decode(header)));
} catch (InvalidInputException | IllegalArgumentException e) {
// Base64 throws IllegalArgumentException; GroupSendCredentialPresentation ctor throws InvalidInputException
throw new WebApplicationException(e, Status.UNAUTHORIZED);
}
}
}

View File

@ -18,8 +18,6 @@ import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class OptionalAccess {
public static final String UNIDENTIFIED = "Unidentified-Access-Key";
public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey,
Optional<Account> targetAccount,

View File

@ -61,6 +61,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -206,7 +207,7 @@ public class KeysController {
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") ServiceIdentifier targetIdentifier,

View File

@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.net.HttpHeaders;
import com.google.protobuf.ByteString;
import io.dropwizard.auth.Auth;
@ -52,6 +54,7 @@ import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
@ -67,13 +70,17 @@ import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
import org.signal.libsignal.protocol.util.Pair;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.GroupSendCredentialHeader;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -113,6 +120,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
@ -148,6 +156,7 @@ public class MessageController {
private final ReportSpamTokenProvider reportSpamTokenProvider;
private final ClientReleaseManager clientReleaseManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ServerSecretParams serverSecretParams;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
@ -188,7 +197,8 @@ public class MessageController {
Scheduler messageDeliveryScheduler,
@Nonnull ReportSpamTokenProvider reportSpamTokenProvider,
final ClientReleaseManager clientReleaseManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ServerSecretParams serverSecretParams) {
this.rateLimiters = rateLimiters;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.messageSender = messageSender;
@ -202,6 +212,7 @@ public class MessageController {
this.reportSpamTokenProvider = reportSpamTokenProvider;
this.clientReleaseManager = clientReleaseManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.serverSecretParams = serverSecretParams;
}
@Timed
@ -211,7 +222,7 @@ public class MessageController {
@Produces(MediaType.APPLICATION_JSON)
@FilterSpam
public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("destination") ServiceIdentifier destinationIdentifier,
@QueryParam("story") boolean isStory,
@ -372,6 +383,7 @@ public class MessageController {
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
.switchIfEmpty(Flux.error(BadRequestException::new))
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
.flatMap(
t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1()))
@ -406,7 +418,9 @@ public class MessageController {
""")
@ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true)
@ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times")
@ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect")
@ApiResponse(
responseCode="401",
description="The message is not a story and the unauthorized access key or group send credential is missing or incorrect")
@ApiResponse(
responseCode="404",
description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users")
@ -416,10 +430,14 @@ public class MessageController {
@ApiResponse(
responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices",
content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class)))
public Response sendMultiRecipientMessage(
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message")
@HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@Deprecated
@Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message. Will be replaced with group send credentials")
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@Parameter(description="A group send credential covering all (included and excluded) recipients of the message. Must not be combined with `Unidentified-Access-Key` or set on a story message.")
@HeaderParam(HeaderUtils.GROUP_SEND_CREDENTIAL)
@Nullable GroupSendCredentialHeader groupSendCredential,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@ -436,14 +454,31 @@ public class MessageController {
@QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal")
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
if (groupSendCredential == null && accessKeys == null && !isStory) {
throw new NotAuthorizedException("A group send credential or unidentified access key is required for non-story messages");
}
if (groupSendCredential != null) {
if (accessKeys != null) {
throw new BadRequestException("Only one of group send credential and unidentified access key may be provided");
} else if (isStory) {
throw new BadRequestExcpetion("Stories should not provide a group send credential");
}
}
if (groupSendCredential != null) {
// Group send credentials are checked before we even attempt to resolve any accounts, since
// the lists of service IDs in the envelope are all that we need to check against
checkGroupSendCredential(
multiRecipientMessage.getRecipients().keySet(), multiRecipientMessage.getExcludedRecipients(), groupSendCredential);
}
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
// Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above.
// Group send credentials are checked earlier; for stories, we don't check permissions at all because only clients check them
if (groupSendCredential == null && !isStory) {
checkAccessKeys(accessKeys, recipients.values());
}
// We might filter out all the recipients of a story (if none exist).
// In this case there is no error so we should just return 200 now.
if (isStory) {
@ -556,12 +591,28 @@ public class MessageController {
return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build();
}
private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection<MultiRecipientDeliveryData> destinations) {
// We should not have null access keys when checking access; bail out early.
if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED);
private void checkGroupSendCredential(
final Collection<ServiceId> recipients,
final Collection<ServiceId> excludedRecipients,
final @NotNull GroupSendCredentialHeader groupSendCredential) {
try {
// A group send credential covers *every* group member except the sender. However, clients
// don't always want to actually send to every recipient in the same multi-send (most
// commonly because a new member needs an SKDM first, but also could be because the sender
// has blocked someone). So we check the group send credential against the combination of
// the actual recipients and the supplied list of "excluded" recipients, accounts the
// sender knows are part of the credential but doesn't want to send to right now.
groupSendCredential.presentation().verify(
Lists.newArrayList(Iterables.concat(recipients, excludedRecipients)),
serverSecretParams);
} catch (VerificationFailedException e) {
throw new NotAuthorizedException(e);
}
}
private void checkAccessKeys(
final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys,
final Collection<MultiRecipientDeliveryData> destinations) {
final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH;
final byte[] combinedUnidentifiedAccessKeys = destinations.stream()
.map(MultiRecipientDeliveryData::account)

View File

@ -94,6 +94,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ProfileHelper;
import org.whispersystems.textsecuregcm.util.Util;
@ -226,7 +227,7 @@ public class ProfileController {
@Path("/{identifier}/{version}")
public VersionedProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version)
@ -246,7 +247,7 @@ public class ProfileController {
@Path("/{identifier}/{version}/{credentialRequest}")
public CredentialProfileResponse getProfile(
@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@PathParam("version") String version,
@ -276,7 +277,7 @@ public class ProfileController {
@Path("/{identifier}")
public BaseProfileResponse getUnversionedProfile(
@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@PathParam("identifier") ServiceIdentifier identifier,

View File

@ -10,6 +10,8 @@ import java.util.Arrays;
import java.util.HexFormat;
import java.util.UUID;
import io.swagger.v3.oas.annotations.media.Schema;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
@ -51,6 +53,11 @@ public record AciServiceIdentifier(UUID uuid) implements ServiceIdentifier {
return byteBuffer.array();
}
@Override
public ServiceId.Aci toLibsignal() {
return new ServiceId.Aci(uuid);
}
public static AciServiceIdentifier valueOf(final String string) {
return new AciServiceIdentifier(
UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix())

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.identity;
import io.swagger.v3.oas.annotations.media.Schema;
import org.signal.libsignal.protocol.ServiceId;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.nio.ByteBuffer;
import java.util.Arrays;
@ -51,6 +53,11 @@ public record PniServiceIdentifier(UUID uuid) implements ServiceIdentifier {
return byteBuffer.array();
}
@Override
public ServiceId.Pni toLibsignal() {
return new ServiceId.Pni(uuid);
}
public static PniServiceIdentifier valueOf(final String string) {
if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) {
throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix");

View File

@ -81,4 +81,6 @@ public interface ServiceIdentifier {
}
throw new IllegalArgumentException("unknown libsignal ServiceId type");
}
ServiceId toLibsignal();
}

View File

@ -24,6 +24,10 @@ public final class HeaderUtils {
public static final String TIMESTAMP_HEADER = "X-Signal-Timestamp";
public static final String UNIDENTIFIED_ACCESS_KEY = "Unidentified-Access-Key";
public static final String GROUP_SEND_CREDENTIAL = "Group-Send-Credential";
private HeaderUtils() {
// utility class
}

View File

@ -40,12 +40,12 @@ import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations;
import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -198,7 +198,7 @@ class CertificateControllerTest {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.get();
assertEquals(response.getStatus(), 401);

View File

@ -48,7 +48,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
@ -494,7 +494,7 @@ class KeysControllerTest {
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
@ -518,7 +518,7 @@ class KeysControllerTest {
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
assertThat(result).isNotNull();
@ -530,7 +530,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@ -542,7 +542,7 @@ class KeysControllerTest {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, "$$$$$$$$$")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, "$$$$$$$$$")
.get();
assertThat(response.getStatus()).isEqualTo(401);

View File

@ -81,10 +81,20 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.util.Hex;
import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.signal.libsignal.zkgroup.groups.GroupSecretParams;
import org.signal.libsignal.zkgroup.groups.UuidCiphertext;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse;
import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
@ -125,6 +135,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@ -137,15 +148,19 @@ import reactor.core.scheduler.Schedulers;
class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final ServiceIdentifier SINGLE_DEVICE_ACI_ID = new AciServiceIdentifier(SINGLE_DEVICE_UUID);
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier SINGLE_DEVICE_PNI_ID = new PniServiceIdentifier(SINGLE_DEVICE_PNI);
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_ACI_ID = new AciServiceIdentifier(MULTI_DEVICE_UUID);
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final ServiceIdentifier MULTI_DEVICE_PNI_ID = new PniServiceIdentifier(MULTI_DEVICE_PNI);
private static final byte MULTI_DEVICE_ID1 = 1;
private static final byte MULTI_DEVICE_ID2 = 2;
private static final byte MULTI_DEVICE_ID3 = 3;
@ -157,6 +172,8 @@ class MessageControllerTest {
private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444;
private static final UUID NONEXISTENT_UUID = UUID.randomUUID();
private static final ServiceIdentifier NONEXISTENT_ACI_ID = new AciServiceIdentifier(NONEXISTENT_UUID);
private static final ServiceIdentifier NONEXISTENT_PNI_ID = new PniServiceIdentifier(NONEXISTENT_UUID);
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
@ -178,6 +195,7 @@ class MessageControllerTest {
private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService();
private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery");
private static final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@ -189,7 +207,8 @@ class MessageControllerTest {
.addResource(
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager))
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams))
.build();
@BeforeEach
@ -213,19 +232,19 @@ class MessageControllerTest {
Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_ACI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_PNI_ID)).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_ACI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_PNI_ID)).thenReturn(Optional.of(multiDeviceAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(NONEXISTENT_ACI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifier(NONEXISTENT_PNI_ID)).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration =
@ -381,7 +400,7 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
@ -904,7 +923,7 @@ class MessageControllerTest {
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true,
System.currentTimeMillis()),
@ -965,7 +984,21 @@ class MessageControllerTest {
bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes)
}
private static void writeMultiPayloadExcludedRecipient(final ByteBuffer bb, final ServiceIdentifier id, final boolean useExplicitIdentifier) {
if (useExplicitIdentifier) {
bb.put(id.toFixedWidthByteArray());
} else {
bb.put(UUIDUtil.toBytes(id.uuid()));
}
bb.put((byte) 0);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer, final boolean explicitIdentifiers) {
return initializeMultiPayload(recipients, List.of(), buffer, explicitIdentifiers);
}
private static InputStream initializeMultiPayload(List<Recipient> recipients, List<ServiceIdentifier> excludedRecipients, byte[] buffer, final boolean explicitIdentifiers) {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
@ -974,17 +1007,15 @@ class MessageControllerTest {
bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte
// count varint
int nRecip = recipients.size();
int nRecip = recipients.size() + excludedRecipients.size();
while (nRecip > 127) {
bb.put((byte) (nRecip & 0x7F | 0x80));
nRecip = nRecip >> 7;
}
bb.put((byte)(nRecip & 0x7F));
Iterator<Recipient> it = recipients.iterator();
while (it.hasNext()) {
writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers);
}
recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers));
excludedRecipients.forEach(recipient -> writeMultiPayloadExcludedRecipient(bb, recipient, explicitIdentifiers));
// now write the actual message body (empty for now)
bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here)
@ -1062,8 +1093,17 @@ class MessageControllerTest {
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
Invocation.Builder bldr = resources
// build correct or incorrect access header
final String accessHeader;
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
accessHeader = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
} else {
accessHeader = "BBBBBBBBBBBBBBBBBBBBBB==";
}
// make the PUT request
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
@ -1071,17 +1111,9 @@ class MessageControllerTest {
.queryParam("story", isStory)
.queryParam("urgent", urgent)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME");
// add access header if needed
if (authorize) {
final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count();
String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader)
.put(entity);
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
@ -1105,18 +1137,18 @@ class MessageControllerTest {
private static Map<ServiceIdentifier, Map<Byte, Integer>> multiRecipientTargetMap() {
return
Map.of(
new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
SINGLE_DEVICE_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
SINGLE_DEVICE_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1),
MULTI_DEVICE_ACI_ID,
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2),
new PniServiceIdentifier(MULTI_DEVICE_PNI),
MULTI_DEVICE_PNI_ID,
Map.of(
MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1,
MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2),
new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
NONEXISTENT_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1),
NONEXISTENT_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1)
);
}
@ -1137,16 +1169,16 @@ class MessageControllerTest {
@SuppressWarnings("unused")
private static ArgumentSets testMultiRecipientMessageNoPni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAci = submap(targets, SINGLE_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDeviceAci = submap(targets, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsAci =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID));
submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeAci =
submap(
targets,
new AciServiceIdentifier(SINGLE_DEVICE_UUID),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new AciServiceIdentifier(NONEXISTENT_UUID));
SINGLE_DEVICE_ACI_ID,
MULTI_DEVICE_ACI_ID,
NONEXISTENT_ACI_ID);
final boolean auth = true;
final boolean unauth = false;
@ -1186,18 +1218,18 @@ class MessageControllerTest {
private static ArgumentSets testMultiRecipientMessagePni() {
final Map<ServiceIdentifier, Map<Byte, Integer>> targets = multiRecipientTargetMap();
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDevicePni = submap(targets, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> singleDeviceAciAndPni = submap(
targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI));
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI));
targets, SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> multiDevicePni = submap(targets, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> bothAccountsMixed =
submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI));
submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_PNI_ID);
final Map<ServiceIdentifier, Map<Byte, Integer>> realAndFakeMixed =
submap(
targets,
new PniServiceIdentifier(SINGLE_DEVICE_PNI),
new AciServiceIdentifier(MULTI_DEVICE_UUID),
new PniServiceIdentifier(NONEXISTENT_UUID));
SINGLE_DEVICE_PNI_ID,
MULTI_DEVICE_ACI_ID,
NONEXISTENT_PNI_ID);
final boolean auth = true;
final boolean unauth = false;
@ -1232,13 +1264,113 @@ class MessageControllerTest {
.argumentsForNextParameter(false, true); // urgent
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessageWithGroupSendCredential(
List<ServiceIdentifier> includedRecipients,
List<ServiceIdentifier> excludedRecipients,
int expectedStatus,
int expectedMessagesSent) throws Exception {
final List<Recipient> recipients = new ArrayList<>();
includedRecipients.forEach(
serviceIdentifier -> multiRecipientTargetMap().get(serviceIdentifier).forEach(
(deviceId, registrationId) ->
recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48]))));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipients, excludedRecipients, buffer, true);
final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID());
Response response = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", false)
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(HeaderUtils.GROUP_SEND_CREDENTIAL, validGroupSendCredentialHeader(
senderId,
List.of(senderId, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID)))
.put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE));
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus)));
verify(messageSender,
exactly(expectedMessagesSent))
.sendMessage(
any(),
any(),
argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()),
eq(true));
if (expectedStatus == 200) {
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assertThat(smrmr.uuids404(), is(empty()));
}
}
private static Stream<Arguments> testMultiRecipientMessageWithGroupSendCredential() {
return Stream.of(
// All members present in included or excluded recipients: success, deliver to included recipients only
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(), 200, 3),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 200, 1),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 200, 2),
// No included recipients: request is bad
Arguments.of(List.of(), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 400, 0),
// Some recipients both included and excluded: request is bad
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 400, 0),
// Included recipient not covered by credential: forbid
Arguments.of(List.of(NONEXISTENT_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(), 401, 0),
// Excluded recipient not covered by credential: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID), 401, 0),
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ACI_ID), 401, 0),
// Some recipients not in included or excluded list: forbid
Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(), 401, 0),
Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(), 401, 0),
// Substituting a PNI for an ACI is not allowed
Arguments.of(List.of(SINGLE_DEVICE_PNI_ID, MULTI_DEVICE_ACI_ID), List.of(), 401, 0));
}
private String validGroupSendCredentialHeader(AciServiceIdentifier sender, List<ServiceIdentifier> allGroupMembers) throws Exception {
final ServerPublicParams serverPublicParams = serverSecretParams.getPublicParams();
final GroupMasterKey groupMasterKey = new GroupMasterKey(new byte[32]);
final GroupSecretParams groupSecretParams = GroupSecretParams.deriveFromMasterKey(groupMasterKey);
final ClientZkGroupCipher clientZkGroupCipher = new ClientZkGroupCipher(groupSecretParams);
UuidCiphertext senderCiphertext = clientZkGroupCipher.encrypt(sender.toLibsignal());
List<UuidCiphertext> groupCiphertexts = allGroupMembers.stream()
.map(ServiceIdentifier::toLibsignal)
.map(clientZkGroupCipher::encrypt)
.collect(Collectors.toList());
GroupSendCredentialResponse credentialResponse =
GroupSendCredentialResponse.issueCredential(groupCiphertexts, senderCiphertext, serverSecretParams);
GroupSendCredential credential =
credentialResponse.receive(
allGroupMembers.stream().map(ServiceIdentifier::toLibsignal).collect(Collectors.toList()),
sender.toLibsignal(),
serverPublicParams,
groupSecretParams);
GroupSendCredentialPresentation presentation = credential.present(serverPublicParams);
return Base64.getEncoder().encodeToString(presentation.serialize());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception {
final List<Recipient> recipients = List.of(
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]));
Response response = resources
.getJerseyTest()
@ -1249,7 +1381,7 @@ class MessageControllerTest {
.queryParam("urgent", false)
.request()
.header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
checkBadMultiRecipientResponse(response, 400);
@ -1266,7 +1398,7 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", unknownUUID))
.queryParam("story", "true")
.request()
.header(OptionalAccess.UNIDENTIFIED, accessBytes)
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes)
.put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE));
assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200)));
@ -1278,13 +1410,13 @@ class MessageControllerTest {
final Recipient r1;
if (known) {
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]);
}
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
Recipient r2 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
Recipient r3 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]);
List<Recipient> recipients = List.of(r1, r2, r3);
@ -1307,7 +1439,7 @@ class MessageControllerTest {
.queryParam("story", story)
.request()
.header(HttpHeaders.USER_AGENT, "Test User Agent")
.header(OptionalAccess.UNIDENTIFIED, accessBytes);
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes);
// make the PUT request
Response response = bldr.put(entity);
@ -1363,7 +1495,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
@ -1381,8 +1513,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageMismatchedDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(MULTI_DEVICE_PNI_ID));
}
@ParameterizedTest
@ -1410,7 +1542,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
// make the PUT request
final Response response = invocationBuilder.put(entity);
@ -1429,8 +1561,8 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI)));
Arguments.of(MULTI_DEVICE_ACI_ID),
Arguments.of(MULTI_DEVICE_PNI_ID));
}
@ParameterizedTest
@ -1460,7 +1592,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
doThrow(NotPushRegisteredException.class)
.when(messageSender).sendMessage(any(), any(), any(), anyBoolean());
@ -1473,13 +1605,13 @@ class MessageControllerTest {
private static Stream<Arguments> sendMultiRecipientMessage404() {
return Stream.of(
Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
Arguments.of(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2),
Arguments.of(MULTI_DEVICE_PNI_ID, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2));
}
@Test
void sendMultiRecipientMessageStoryRateLimited() {
final List<Recipient> recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
final List<Recipient> recipients = List.of(new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]));
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer);
@ -1498,7 +1630,7 @@ class MessageControllerTest {
.queryParam("urgent", true)
.request()
.header(HttpHeaders.USER_AGENT, "FIXME")
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES));
when(rateLimiter.validateAsync(any(UUID.class)))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true)));

View File

@ -75,7 +75,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest;
import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext;
import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -107,6 +106,7 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -279,7 +279,7 @@ class ProfileControllerTest {
final BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY);
@ -295,7 +295,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@ -306,7 +306,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + UUID.randomUUID())
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@ -351,7 +351,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@ -365,7 +365,7 @@ class ProfileControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes()))
.get();
assertThat(response.getStatus()).isEqualTo(401);
@ -1140,7 +1140,7 @@ class ProfileControllerTest {
private static Stream<Arguments> testGetProfileWithExpiringProfileKeyCredential() {
return Stream.of(
Arguments.of(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))),
Arguments.of(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)))),
Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))))
);
@ -1184,7 +1184,7 @@ class ProfileControllerTest {
HexFormat.of().formatHex(credentialRequest.serialize())))
.queryParam("credentialType", "expiringProfileKey")
.request()
.headers(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY))))
.headers(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY))))
.get();
assertEquals(400, response.getStatus());