diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 362e5f4fd..b7cc75a9c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -812,7 +812,7 @@ public class WhisperServerService extends Application requestAccount, Optional accessKey, Optional targetAccount, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 182a5e7f0..b9e8caac4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -61,6 +61,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -206,7 +207,7 @@ public class KeysController { name = "Retry-After", description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed")) public PreKeyResponse getDeviceKeys(@Auth Optional auth, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Parameter(description="the account or phone-number identifier to retrieve keys for") @PathParam("identifier") ServiceIdentifier targetIdentifier, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 8d4f3e055..7b6407e03 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -8,6 +8,8 @@ import static com.codahale.metrics.MetricRegistry.name; import com.codahale.metrics.annotation.Timed; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import com.google.common.net.HttpHeaders; import com.google.protobuf.ByteString; import io.dropwizard.auth.Auth; @@ -52,6 +54,7 @@ import javax.ws.rs.DELETE; import javax.ws.rs.DefaultValue; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; +import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.NotFoundException; import javax.ws.rs.POST; import javax.ws.rs.PUT; @@ -67,13 +70,17 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage; +import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient; import org.signal.libsignal.protocol.util.Pair; +import org.signal.libsignal.zkgroup.ServerSecretParams; +import org.signal.libsignal.zkgroup.VerificationFailedException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; +import org.whispersystems.textsecuregcm.auth.GroupSendCredentialHeader; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -113,6 +120,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.Stories; @@ -148,6 +156,7 @@ public class MessageController { private final ReportSpamTokenProvider reportSpamTokenProvider; private final ClientReleaseManager clientReleaseManager; private final DynamicConfigurationManager dynamicConfigurationManager; + private final ServerSecretParams serverSecretParams; private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8; @@ -188,7 +197,8 @@ public class MessageController { Scheduler messageDeliveryScheduler, @Nonnull ReportSpamTokenProvider reportSpamTokenProvider, final ClientReleaseManager clientReleaseManager, - final DynamicConfigurationManager dynamicConfigurationManager) { + final DynamicConfigurationManager dynamicConfigurationManager, + final ServerSecretParams serverSecretParams) { this.rateLimiters = rateLimiters; this.messageByteLimitEstimator = messageByteLimitEstimator; this.messageSender = messageSender; @@ -202,6 +212,7 @@ public class MessageController { this.reportSpamTokenProvider = reportSpamTokenProvider; this.clientReleaseManager = clientReleaseManager; this.dynamicConfigurationManager = dynamicConfigurationManager; + this.serverSecretParams = serverSecretParams; } @Timed @@ -211,7 +222,7 @@ public class MessageController { @Produces(MediaType.APPLICATION_JSON) @FilterSpam public Response sendMessage(@Auth Optional source, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @PathParam("destination") ServiceIdentifier destinationIdentifier, @QueryParam("story") boolean isStory, @@ -372,6 +383,7 @@ public class MessageController { private Map buildRecipientMap( SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) { return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet()) + .switchIfEmpty(Flux.error(BadRequestException::new)) .map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue())) .flatMap( t -> Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(t.getT1())) @@ -406,7 +418,9 @@ public class MessageController { """) @ApiResponse(responseCode="200", description="Message was successfully sent to all recipients", useReturnTypeSchema=true) @ApiResponse(responseCode="400", description="The envelope specified delivery to the same recipient device multiple times") - @ApiResponse(responseCode="401", description="The message is not a story and the unauthorized access key is incorrect") + @ApiResponse( + responseCode="401", + description="The message is not a story and the unauthorized access key or group send credential is missing or incorrect") @ApiResponse( responseCode="404", description="The message is not a story and some of the recipient service IDs do not correspond to registered Signal users") @@ -416,10 +430,14 @@ public class MessageController { @ApiResponse( responseCode = "410", description = "Mismatched registration ids supplied for some recipient devices", content = @Content(schema = @Schema(implementation = AccountStaleDevices[].class))) - public Response sendMultiRecipientMessage( - @Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message") - @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, + @Deprecated + @Parameter(description="The bitwise xor of the unidentified access keys for every recipient of the message. Will be replaced with group send credentials") + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, + + @Parameter(description="A group send credential covering all (included and excluded) recipients of the message. Must not be combined with `Unidentified-Access-Key` or set on a story message.") + @HeaderParam(HeaderUtils.GROUP_SEND_CREDENTIAL) + @Nullable GroupSendCredentialHeader groupSendCredential, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @@ -436,14 +454,31 @@ public class MessageController { @QueryParam("story") boolean isStory, @Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal") @NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException { + if (groupSendCredential == null && accessKeys == null && !isStory) { + throw new NotAuthorizedException("A group send credential or unidentified access key is required for non-story messages"); + } + if (groupSendCredential != null) { + if (accessKeys != null) { + throw new BadRequestException("Only one of group send credential and unidentified access key may be provided"); + } else if (isStory) { + throw new BadRequestExcpetion("Stories should not provide a group send credential"); + } + } + + if (groupSendCredential != null) { + // Group send credentials are checked before we even attempt to resolve any accounts, since + // the lists of service IDs in the envelope are all that we need to check against + checkGroupSendCredential( + multiRecipientMessage.getRecipients().keySet(), multiRecipientMessage.getExcludedRecipients(), groupSendCredential); + } final Map recipients = buildRecipientMap(multiRecipientMessage, isStory); - // Stories will be checked by the client; we bypass access checks here for stories. - if (!isStory) { + // Access keys are checked against the UAK in the resolved accounts, so we have to check after resolving accounts above. + // Group send credentials are checked earlier; for stories, we don't check permissions at all because only clients check them + if (groupSendCredential == null && !isStory) { checkAccessKeys(accessKeys, recipients.values()); } - // We might filter out all the recipients of a story (if none exist). // In this case there is no error so we should just return 200 now. if (isStory) { @@ -556,12 +591,28 @@ public class MessageController { return Response.ok(new SendMultiRecipientMessageResponse(uuids404)).build(); } - private void checkAccessKeys(final CombinedUnidentifiedSenderAccessKeys accessKeys, final Collection destinations) { - // We should not have null access keys when checking access; bail out early. - if (accessKeys == null) { - throw new WebApplicationException(Status.UNAUTHORIZED); + private void checkGroupSendCredential( + final Collection recipients, + final Collection excludedRecipients, + final @NotNull GroupSendCredentialHeader groupSendCredential) { + try { + // A group send credential covers *every* group member except the sender. However, clients + // don't always want to actually send to every recipient in the same multi-send (most + // commonly because a new member needs an SKDM first, but also could be because the sender + // has blocked someone). So we check the group send credential against the combination of + // the actual recipients and the supplied list of "excluded" recipients, accounts the + // sender knows are part of the credential but doesn't want to send to right now. + groupSendCredential.presentation().verify( + Lists.newArrayList(Iterables.concat(recipients, excludedRecipients)), + serverSecretParams); + } catch (VerificationFailedException e) { + throw new NotAuthorizedException(e); } + } + private void checkAccessKeys( + final @NotNull CombinedUnidentifiedSenderAccessKeys accessKeys, + final Collection destinations) { final int keyLength = UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH; final byte[] combinedUnidentifiedAccessKeys = destinations.stream() .map(MultiRecipientDeliveryData::account) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 664882962..9e355ae5e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -94,6 +94,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.VersionedProfile; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ProfileHelper; import org.whispersystems.textsecuregcm.util.Util; @@ -226,7 +227,7 @@ public class ProfileController { @Path("/{identifier}/{version}") public VersionedProfileResponse getProfile( @Auth Optional auth, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("version") String version) @@ -246,7 +247,7 @@ public class ProfileController { @Path("/{identifier}/{version}/{credentialRequest}") public CredentialProfileResponse getProfile( @Auth Optional auth, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @PathParam("identifier") AciServiceIdentifier accountIdentifier, @PathParam("version") String version, @@ -276,7 +277,7 @@ public class ProfileController { @Path("/{identifier}") public BaseProfileResponse getUnversionedProfile( @Auth Optional auth, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, @PathParam("identifier") ServiceIdentifier identifier, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java index ba08c42ba..a3a6b7246 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/AciServiceIdentifier.java @@ -10,6 +10,8 @@ import java.util.Arrays; import java.util.HexFormat; import java.util.UUID; import io.swagger.v3.oas.annotations.media.Schema; + +import org.signal.libsignal.protocol.ServiceId; import org.whispersystems.textsecuregcm.util.UUIDUtil; /** @@ -51,6 +53,11 @@ public record AciServiceIdentifier(UUID uuid) implements ServiceIdentifier { return byteBuffer.array(); } + @Override + public ServiceId.Aci toLibsignal() { + return new ServiceId.Aci(uuid); + } + public static AciServiceIdentifier valueOf(final String string) { return new AciServiceIdentifier( UUID.fromString(string.startsWith(IDENTITY_TYPE.getStringPrefix()) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java index 2f184fd20..6ef3938aa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/PniServiceIdentifier.java @@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.identity; import io.swagger.v3.oas.annotations.media.Schema; + +import org.signal.libsignal.protocol.ServiceId; import org.whispersystems.textsecuregcm.util.UUIDUtil; import java.nio.ByteBuffer; import java.util.Arrays; @@ -51,6 +53,11 @@ public record PniServiceIdentifier(UUID uuid) implements ServiceIdentifier { return byteBuffer.array(); } + @Override + public ServiceId.Pni toLibsignal() { + return new ServiceId.Pni(uuid); + } + public static PniServiceIdentifier valueOf(final String string) { if (!string.startsWith(IDENTITY_TYPE.getStringPrefix())) { throw new IllegalArgumentException("PNI account identifier did not start with \"PNI:\" prefix"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java index ab52312a8..43c0d03b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/identity/ServiceIdentifier.java @@ -81,4 +81,6 @@ public interface ServiceIdentifier { } throw new IllegalArgumentException("unknown libsignal ServiceId type"); } + + ServiceId toLibsignal(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java index 89b9e7b3b..01ff4928a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java @@ -24,6 +24,10 @@ public final class HeaderUtils { public static final String TIMESTAMP_HEADER = "X-Signal-Timestamp"; + public static final String UNIDENTIFIED_ACCESS_KEY = "Unidentified-Access-Key"; + + public static final String GROUP_SEND_CREDENTIAL = "Group-Send-Credential"; + private HeaderUtils() { // utility class } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java index 8ce182f03..ef0a3de9b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java @@ -40,12 +40,12 @@ import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations; import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; -import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -198,7 +198,7 @@ class CertificateControllerTest { Response response = resources.getJerseyTest() .target("/v1/certificate/delivery") .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes())) .get(); assertEquals(response.getStatus(), 401); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 84dc3395f..062de0091 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -48,7 +48,6 @@ import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; @@ -73,6 +72,7 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; +import org.whispersystems.textsecuregcm.util.HeaderUtils; @ExtendWith(DropwizardExtensionsSupport.class) class KeysControllerTest { @@ -494,7 +494,7 @@ class KeysControllerTest { .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .queryParam("pq", "true") .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .get(PreKeyResponse.class); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); @@ -518,7 +518,7 @@ class KeysControllerTest { Response result = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .get(); assertThat(result).isNotNull(); @@ -530,7 +530,7 @@ class KeysControllerTest { Response response = resources.getJerseyTest() .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("9999".getBytes())) .get(); assertThat(response.getStatus()).isEqualTo(401); @@ -542,7 +542,7 @@ class KeysControllerTest { Response response = resources.getJerseyTest() .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, "$$$$$$$$$") + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, "$$$$$$$$$") .get(); assertThat(response.getStatus()).isEqualTo(401); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index cfaf12d29..87ac42283 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -81,10 +81,20 @@ import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.ArgumentSets; import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.ServiceId; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.util.Hex; +import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher; +import org.signal.libsignal.zkgroup.groups.GroupMasterKey; +import org.signal.libsignal.zkgroup.groups.GroupSecretParams; +import org.signal.libsignal.zkgroup.groups.UuidCiphertext; +import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential; +import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation; +import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse; +import org.signal.libsignal.zkgroup.ServerPublicParams; +import org.signal.libsignal.zkgroup.ServerSecretParams; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; @@ -125,6 +135,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UUIDUtil; @@ -137,15 +148,19 @@ import reactor.core.scheduler.Schedulers; class MessageControllerTest { private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; - private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); - private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); + private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); + private static final ServiceIdentifier SINGLE_DEVICE_ACI_ID = new AciServiceIdentifier(SINGLE_DEVICE_UUID); + private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); + private static final ServiceIdentifier SINGLE_DEVICE_PNI_ID = new PniServiceIdentifier(SINGLE_DEVICE_PNI); private static final byte SINGLE_DEVICE_ID1 = 1; private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final int SINGLE_DEVICE_PNI_REG_ID1 = 1111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); + private static final ServiceIdentifier MULTI_DEVICE_ACI_ID = new AciServiceIdentifier(MULTI_DEVICE_UUID); private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); + private static final ServiceIdentifier MULTI_DEVICE_PNI_ID = new PniServiceIdentifier(MULTI_DEVICE_PNI); private static final byte MULTI_DEVICE_ID1 = 1; private static final byte MULTI_DEVICE_ID2 = 2; private static final byte MULTI_DEVICE_ID3 = 3; @@ -157,6 +172,8 @@ class MessageControllerTest { private static final int MULTI_DEVICE_PNI_REG_ID3 = 4444; private static final UUID NONEXISTENT_UUID = UUID.randomUUID(); + private static final ServiceIdentifier NONEXISTENT_ACI_ID = new AciServiceIdentifier(NONEXISTENT_UUID); + private static final ServiceIdentifier NONEXISTENT_PNI_ID = new PniServiceIdentifier(NONEXISTENT_UUID); private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); @@ -178,6 +195,7 @@ class MessageControllerTest { private static final ExecutorService multiRecipientMessageExecutor = MoreExecutors.newDirectExecutorService(); private static final Scheduler messageDeliveryScheduler = Schedulers.newBoundedElastic(10, 10_000, "messageDelivery"); private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate(); private static final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) @@ -189,7 +207,8 @@ class MessageControllerTest { .addResource( new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, - messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager)) + messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager, + serverSecretParams)) .build(); @BeforeEach @@ -213,19 +232,19 @@ class MessageControllerTest { Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); - when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(Optional.of(singleDeviceAccount)); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(Optional.of(multiDeviceAccount)); - when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(Optional.of(multiDeviceAccount)); + when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_ACI_ID)).thenReturn(Optional.of(singleDeviceAccount)); + when(accountsManager.getByServiceIdentifier(SINGLE_DEVICE_PNI_ID)).thenReturn(Optional.of(singleDeviceAccount)); + when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_ACI_ID)).thenReturn(Optional.of(multiDeviceAccount)); + when(accountsManager.getByServiceIdentifier(MULTI_DEVICE_PNI_ID)).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(Optional.of(internationalAccount)); - when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); - when(accountsManager.getByServiceIdentifier(new PniServiceIdentifier(NONEXISTENT_UUID))).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifier(NONEXISTENT_ACI_ID)).thenReturn(Optional.empty()); + when(accountsManager.getByServiceIdentifier(NONEXISTENT_PNI_ID)).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(SINGLE_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); - when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(SINGLE_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); - when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(MULTI_DEVICE_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); - when(accountsManager.getByServiceIdentifierAsync(new PniServiceIdentifier(MULTI_DEVICE_PNI))).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(SINGLE_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(singleDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_ACI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); + when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount))); final DynamicInboundMessageByteLimitConfiguration inboundMessageByteLimitConfiguration = @@ -381,7 +400,7 @@ class MessageControllerTest { resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(SystemMapper.jsonMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), MediaType.APPLICATION_JSON_TYPE)); @@ -904,7 +923,7 @@ class MessageControllerTest { resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(new IncomingMessageList( List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true, System.currentTimeMillis()), @@ -965,7 +984,21 @@ class MessageControllerTest { bb.put(r.perRecipientKeyMaterial()); // key material (48 bytes) } + private static void writeMultiPayloadExcludedRecipient(final ByteBuffer bb, final ServiceIdentifier id, final boolean useExplicitIdentifier) { + if (useExplicitIdentifier) { + bb.put(id.toFixedWidthByteArray()); + } else { + bb.put(UUIDUtil.toBytes(id.uuid())); + } + + bb.put((byte) 0); + } + private static InputStream initializeMultiPayload(List recipients, byte[] buffer, final boolean explicitIdentifiers) { + return initializeMultiPayload(recipients, List.of(), buffer, explicitIdentifiers); + } + + private static InputStream initializeMultiPayload(List recipients, List excludedRecipients, byte[] buffer, final boolean explicitIdentifiers) { // initialize a binary payload according to our wire format ByteBuffer bb = ByteBuffer.wrap(buffer); bb.order(ByteOrder.BIG_ENDIAN); @@ -974,17 +1007,15 @@ class MessageControllerTest { bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte // count varint - int nRecip = recipients.size(); + int nRecip = recipients.size() + excludedRecipients.size(); while (nRecip > 127) { bb.put((byte) (nRecip & 0x7F | 0x80)); nRecip = nRecip >> 7; } bb.put((byte)(nRecip & 0x7F)); - Iterator it = recipients.iterator(); - while (it.hasNext()) { - writeMultiPayloadRecipient(bb, it.next(), explicitIdentifiers); - } + recipients.forEach(recipient -> writeMultiPayloadRecipient(bb, recipient, explicitIdentifiers)); + excludedRecipients.forEach(recipient -> writeMultiPayloadExcludedRecipient(bb, recipient, explicitIdentifiers)); // now write the actual message body (empty for now) bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here) @@ -1062,8 +1093,17 @@ class MessageControllerTest { // set up the entity to use in our PUT request Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); - // start building the request - Invocation.Builder bldr = resources + // build correct or incorrect access header + final String accessHeader; + if (authorize) { + final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); + accessHeader = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]); + } else { + accessHeader = "BBBBBBBBBBBBBBBBBBBBBB=="; + } + + // make the PUT request + Response response = resources .getJerseyTest() .target("/v1/messages/multi_recipient") .queryParam("online", true) @@ -1071,17 +1111,9 @@ class MessageControllerTest { .queryParam("story", isStory) .queryParam("urgent", urgent) .request() - .header(HttpHeaders.USER_AGENT, "FIXME"); - - // add access header if needed - if (authorize) { - final long count = destinations.keySet().stream().map(accountsManager::getByServiceIdentifier).filter(Optional::isPresent).count(); - String encodedBytes = Base64.getEncoder().encodeToString(count % 2 == 1 ? UNIDENTIFIED_ACCESS_BYTES : new byte[16]); - bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); - } - - // make the PUT request - Response response = bldr.put(entity); + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessHeader) + .put(entity); assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); verify(messageSender, @@ -1105,18 +1137,18 @@ class MessageControllerTest { private static Map> multiRecipientTargetMap() { return Map.of( - new AciServiceIdentifier(SINGLE_DEVICE_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - new PniServiceIdentifier(SINGLE_DEVICE_PNI), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), - new AciServiceIdentifier(MULTI_DEVICE_UUID), + SINGLE_DEVICE_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), + SINGLE_DEVICE_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1), + MULTI_DEVICE_ACI_ID, Map.of( MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2), - new PniServiceIdentifier(MULTI_DEVICE_PNI), + MULTI_DEVICE_PNI_ID, Map.of( MULTI_DEVICE_ID1, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_ID2, MULTI_DEVICE_PNI_REG_ID2), - new AciServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), - new PniServiceIdentifier(NONEXISTENT_UUID), Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) + NONEXISTENT_ACI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1), + NONEXISTENT_PNI_ID, Map.of(SINGLE_DEVICE_ID1, SINGLE_DEVICE_PNI_REG_ID1) ); } @@ -1137,16 +1169,16 @@ class MessageControllerTest { @SuppressWarnings("unused") private static ArgumentSets testMultiRecipientMessageNoPni() { final Map> targets = multiRecipientTargetMap(); - final Map> singleDeviceAci = submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID)); - final Map> multiDeviceAci = submap(targets, new AciServiceIdentifier(MULTI_DEVICE_UUID)); + final Map> singleDeviceAci = submap(targets, SINGLE_DEVICE_ACI_ID); + final Map> multiDeviceAci = submap(targets, MULTI_DEVICE_ACI_ID); final Map> bothAccountsAci = - submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new AciServiceIdentifier(MULTI_DEVICE_UUID)); + submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID); final Map> realAndFakeAci = submap( targets, - new AciServiceIdentifier(SINGLE_DEVICE_UUID), - new AciServiceIdentifier(MULTI_DEVICE_UUID), - new AciServiceIdentifier(NONEXISTENT_UUID)); + SINGLE_DEVICE_ACI_ID, + MULTI_DEVICE_ACI_ID, + NONEXISTENT_ACI_ID); final boolean auth = true; final boolean unauth = false; @@ -1186,18 +1218,18 @@ class MessageControllerTest { private static ArgumentSets testMultiRecipientMessagePni() { final Map> targets = multiRecipientTargetMap(); - final Map> singleDevicePni = submap(targets, new PniServiceIdentifier(SINGLE_DEVICE_PNI)); + final Map> singleDevicePni = submap(targets, SINGLE_DEVICE_PNI_ID); final Map> singleDeviceAciAndPni = submap( - targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(SINGLE_DEVICE_PNI)); - final Map> multiDevicePni = submap(targets, new PniServiceIdentifier(MULTI_DEVICE_PNI)); + targets, SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_PNI_ID); + final Map> multiDevicePni = submap(targets, MULTI_DEVICE_PNI_ID); final Map> bothAccountsMixed = - submap(targets, new AciServiceIdentifier(SINGLE_DEVICE_UUID), new PniServiceIdentifier(MULTI_DEVICE_PNI)); + submap(targets, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_PNI_ID); final Map> realAndFakeMixed = submap( targets, - new PniServiceIdentifier(SINGLE_DEVICE_PNI), - new AciServiceIdentifier(MULTI_DEVICE_UUID), - new PniServiceIdentifier(NONEXISTENT_UUID)); + SINGLE_DEVICE_PNI_ID, + MULTI_DEVICE_ACI_ID, + NONEXISTENT_PNI_ID); final boolean auth = true; final boolean unauth = false; @@ -1232,13 +1264,113 @@ class MessageControllerTest { .argumentsForNextParameter(false, true); // urgent } + @ParameterizedTest + @MethodSource + void testMultiRecipientMessageWithGroupSendCredential( + List includedRecipients, + List excludedRecipients, + int expectedStatus, + int expectedMessagesSent) throws Exception { + final List recipients = new ArrayList<>(); + includedRecipients.forEach( + serviceIdentifier -> multiRecipientTargetMap().get(serviceIdentifier).forEach( + (deviceId, registrationId) -> + recipients.add(new Recipient(serviceIdentifier, deviceId, registrationId, new byte[48])))); + + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + InputStream stream = initializeMultiPayload(recipients, excludedRecipients, buffer, true); + final AciServiceIdentifier senderId = new AciServiceIdentifier(UUID.randomUUID()); + + Response response = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("ts", 1663798405641L) + .queryParam("story", false) + .queryParam("urgent", false) + .request() + .header(HttpHeaders.USER_AGENT, "FIXME") + .header(HeaderUtils.GROUP_SEND_CREDENTIAL, validGroupSendCredentialHeader( + senderId, + List.of(senderId, SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID))) + .put(Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE)); + + assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedStatus))); + verify(messageSender, + exactly(expectedMessagesSent)) + .sendMessage( + any(), + any(), + argThat(env -> !env.hasSourceUuid() && !env.hasSourceDevice()), + eq(true)); + if (expectedStatus == 200) { + SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); + assertThat(smrmr.uuids404(), is(empty())); + } + } + + private static Stream testMultiRecipientMessageWithGroupSendCredential() { + return Stream.of( + // All members present in included or excluded recipients: success, deliver to included recipients only + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(), 200, 3), + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 200, 1), + Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 200, 2), + + // No included recipients: request is bad + Arguments.of(List.of(), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 400, 0), + + // Some recipients both included and excluded: request is bad + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID), 400, 0), + + // Included recipient not covered by credential: forbid + Arguments.of(List.of(NONEXISTENT_ACI_ID), List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0), + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(MULTI_DEVICE_ACI_ID), 401, 0), + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID, NONEXISTENT_ACI_ID), List.of(), 401, 0), + + // Excluded recipient not covered by credential: forbid + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID, MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID), 401, 0), + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, MULTI_DEVICE_ACI_ID), 401, 0), + Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(NONEXISTENT_ACI_ID, SINGLE_DEVICE_ACI_ID), 401, 0), + + // Some recipients not in included or excluded list: forbid + Arguments.of(List.of(SINGLE_DEVICE_ACI_ID), List.of(), 401, 0), + Arguments.of(List.of(MULTI_DEVICE_ACI_ID), List.of(), 401, 0), + + // Substituting a PNI for an ACI is not allowed + Arguments.of(List.of(SINGLE_DEVICE_PNI_ID, MULTI_DEVICE_ACI_ID), List.of(), 401, 0)); + } + + private String validGroupSendCredentialHeader(AciServiceIdentifier sender, List allGroupMembers) throws Exception { + final ServerPublicParams serverPublicParams = serverSecretParams.getPublicParams(); + final GroupMasterKey groupMasterKey = new GroupMasterKey(new byte[32]); + final GroupSecretParams groupSecretParams = GroupSecretParams.deriveFromMasterKey(groupMasterKey); + final ClientZkGroupCipher clientZkGroupCipher = new ClientZkGroupCipher(groupSecretParams); + + UuidCiphertext senderCiphertext = clientZkGroupCipher.encrypt(sender.toLibsignal()); + List groupCiphertexts = allGroupMembers.stream() + .map(ServiceIdentifier::toLibsignal) + .map(clientZkGroupCipher::encrypt) + .collect(Collectors.toList()); + GroupSendCredentialResponse credentialResponse = + GroupSendCredentialResponse.issueCredential(groupCiphertexts, senderCiphertext, serverSecretParams); + GroupSendCredential credential = + credentialResponse.receive( + allGroupMembers.stream().map(ServiceIdentifier::toLibsignal).collect(Collectors.toList()), + sender.toLibsignal(), + serverPublicParams, + groupSecretParams); + GroupSendCredentialPresentation presentation = credential.present(serverPublicParams); + return Base64.getEncoder().encodeToString(presentation.serialize()); + } + @ParameterizedTest @ValueSource(booleans = {true, false}) void testMultiRecipientRedisBombProtection(final boolean useExplicitIdentifier) throws Exception { final List recipients = List.of( - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), - new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); + new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), + new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), + new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48])); Response response = resources .getJerseyTest() @@ -1249,7 +1381,7 @@ class MessageControllerTest { .queryParam("urgent", false) .request() .header(HttpHeaders.USER_AGENT, "cluck cluck, i'm a parrot") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE)); checkBadMultiRecipientResponse(response, 400); @@ -1266,7 +1398,7 @@ class MessageControllerTest { .target(String.format("/v1/messages/%s", unknownUUID)) .queryParam("story", "true") .request() - .header(OptionalAccess.UNIDENTIFIED, accessBytes) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes) .put(Entity.entity(list, MediaType.APPLICATION_JSON_TYPE)); assertThat("200 masks unknown recipient", response.getStatus(), is(equalTo(200))); @@ -1278,13 +1410,13 @@ class MessageControllerTest { final Recipient r1; if (known) { - r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); + r1 = new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); } else { r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]); } - Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); - Recipient r3 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); + Recipient r2 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); + Recipient r3 = new Recipient(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]); List recipients = List.of(r1, r2, r3); @@ -1307,7 +1439,7 @@ class MessageControllerTest { .queryParam("story", story) .request() .header(HttpHeaders.USER_AGENT, "Test User Agent") - .header(OptionalAccess.UNIDENTIFIED, accessBytes); + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, accessBytes); // make the PUT request Response response = bldr.put(entity); @@ -1363,7 +1495,7 @@ class MessageControllerTest { .queryParam("urgent", true) .request() .header(HttpHeaders.USER_AGENT, "FIXME") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); // make the PUT request final Response response = invocationBuilder.put(entity); @@ -1381,8 +1513,8 @@ class MessageControllerTest { private static Stream sendMultiRecipientMessageMismatchedDevices() { return Stream.of( - Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), - Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); + Arguments.of(MULTI_DEVICE_ACI_ID), + Arguments.of(MULTI_DEVICE_PNI_ID)); } @ParameterizedTest @@ -1410,7 +1542,7 @@ class MessageControllerTest { .queryParam("urgent", true) .request() .header(HttpHeaders.USER_AGENT, "FIXME") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); // make the PUT request final Response response = invocationBuilder.put(entity); @@ -1429,8 +1561,8 @@ class MessageControllerTest { private static Stream sendMultiRecipientMessageStaleDevices() { return Stream.of( - Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID)), - Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI))); + Arguments.of(MULTI_DEVICE_ACI_ID), + Arguments.of(MULTI_DEVICE_PNI_ID)); } @ParameterizedTest @@ -1460,7 +1592,7 @@ class MessageControllerTest { .queryParam("urgent", true) .request() .header(HttpHeaders.USER_AGENT, "FIXME") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); doThrow(NotPushRegisteredException.class) .when(messageSender).sendMessage(any(), any(), any(), anyBoolean()); @@ -1473,13 +1605,13 @@ class MessageControllerTest { private static Stream sendMultiRecipientMessage404() { return Stream.of( - Arguments.of(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), - Arguments.of(new PniServiceIdentifier(MULTI_DEVICE_PNI), MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); + Arguments.of(MULTI_DEVICE_ACI_ID, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_REG_ID2), + Arguments.of(MULTI_DEVICE_PNI_ID, MULTI_DEVICE_PNI_REG_ID1, MULTI_DEVICE_PNI_REG_ID2)); } @Test void sendMultiRecipientMessageStoryRateLimited() { - final List recipients = List.of(new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); + final List recipients = List.of(new Recipient(SINGLE_DEVICE_ACI_ID, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48])); // initialize our binary payload and create an input stream byte[] buffer = new byte[2048]; // InputStream stream = initializeMultiPayload(recipientUUID, buffer); @@ -1498,7 +1630,7 @@ class MessageControllerTest { .queryParam("urgent", true) .request() .header(HttpHeaders.USER_AGENT, "FIXME") - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)); when(rateLimiter.validateAsync(any(UUID.class))) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofSeconds(77), true))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index eaeaa3b53..a7808ff21 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -75,7 +75,6 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequest; import org.signal.libsignal.zkgroup.profiles.ProfileKeyCredentialRequestContext; import org.signal.libsignal.zkgroup.profiles.ServerZkProfileOperations; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -107,6 +106,7 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @@ -279,7 +279,7 @@ class ProfileControllerTest { final BaseProfileResponse profile = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .get(BaseProfileResponse.class); assertThat(profile.getIdentityKey()).isEqualTo(ACCOUNT_TWO_IDENTITY_KEY); @@ -295,7 +295,7 @@ class ProfileControllerTest { final Response response = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) .get(); assertThat(response.getStatus()).isEqualTo(401); @@ -306,7 +306,7 @@ class ProfileControllerTest { final Response response = resources.getJerseyTest() .target("/v1/profile/" + UUID.randomUUID()) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .get(); assertThat(response.getStatus()).isEqualTo(401); @@ -351,7 +351,7 @@ class ProfileControllerTest { final Response response = resources.getJerseyTest() .target("/v1/profile/PNI:" + AuthHelper.VALID_PNI_TWO) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader(UNIDENTIFIED_ACCESS_KEY)) .get(); assertThat(response.getStatus()).isEqualTo(401); @@ -365,7 +365,7 @@ class ProfileControllerTest { final Response response = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) + .header(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, AuthHelper.getUnidentifiedAccessHeader("incorrect".getBytes())) .get(); assertThat(response.getStatus()).isEqualTo(401); @@ -1140,7 +1140,7 @@ class ProfileControllerTest { private static Stream testGetProfileWithExpiringProfileKeyCredential() { return Stream.of( - Arguments.of(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))), + Arguments.of(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))), Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)))), Arguments.of(new MultivaluedHashMap<>(Map.of("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO)))) ); @@ -1184,7 +1184,7 @@ class ProfileControllerTest { HexFormat.of().formatHex(credentialRequest.serialize()))) .queryParam("credentialType", "expiringProfileKey") .request() - .headers(new MultivaluedHashMap<>(Map.of(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))) + .headers(new MultivaluedHashMap<>(Map.of(HeaderUtils.UNIDENTIFIED_ACCESS_KEY, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_KEY)))) .get(); assertEquals(400, response.getStatus());