Replace MultiRecipientMessage parsing with libsignal's implementation
Co-authored-by: Jonathan Klabunde Tomer <jkt@signal.org>
This commit is contained in:
parent
f20d3043d6
commit
2ab3c97ee8
2
pom.xml
2
pom.xml
|
@ -284,7 +284,7 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.signal</groupId>
|
<groupId>org.signal</groupId>
|
||||||
<artifactId>libsignal-server</artifactId>
|
<artifactId>libsignal-server</artifactId>
|
||||||
<version>0.33.0</version>
|
<version>0.35.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.logging.log4j</groupId>
|
<groupId>org.apache.logging.log4j</groupId>
|
||||||
|
|
|
@ -25,13 +25,9 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||||
import java.security.MessageDigest;
|
import java.security.MessageDigest;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -42,9 +38,6 @@ import java.util.concurrent.CancellationException;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
import java.util.function.BiConsumer;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.function.Predicate;
|
import java.util.function.Predicate;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
|
@ -73,6 +66,9 @@ import javax.ws.rs.core.MediaType;
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
import javax.ws.rs.core.Response.Status;
|
import javax.ws.rs.core.Response.Status;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage.Recipient;
|
||||||
|
import org.signal.libsignal.protocol.util.Pair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.auth.Anonymous;
|
import org.whispersystems.textsecuregcm.auth.Anonymous;
|
||||||
|
@ -88,8 +84,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
|
||||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||||
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
|
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
|
||||||
|
@ -118,13 +112,13 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
|
||||||
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
|
||||||
import org.whispersystems.websocket.Stories;
|
import org.whispersystems.websocket.Stories;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.core.scheduler.Scheduler;
|
import reactor.core.scheduler.Scheduler;
|
||||||
|
import reactor.util.function.Tuples;
|
||||||
|
|
||||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||||
@Path("/v1/messages")
|
@Path("/v1/messages")
|
||||||
|
@ -134,7 +128,8 @@ public class MessageController {
|
||||||
private record MultiRecipientDeliveryData(
|
private record MultiRecipientDeliveryData(
|
||||||
ServiceIdentifier serviceIdentifier,
|
ServiceIdentifier serviceIdentifier,
|
||||||
Account account,
|
Account account,
|
||||||
Map<Byte, Recipient> perDeviceData) {
|
Recipient recipient,
|
||||||
|
Map<Byte, Short> deviceIdToRegistrationId) {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
|
||||||
|
@ -369,20 +364,22 @@ public class MessageController {
|
||||||
* Build mapping of service IDs to resolved accounts and device/registration IDs
|
* Build mapping of service IDs to resolved accounts and device/registration IDs
|
||||||
*/
|
*/
|
||||||
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
|
private Map<ServiceIdentifier, MultiRecipientDeliveryData> buildRecipientMap(
|
||||||
MultiRecipientMessage multiRecipientMessage, boolean isStory) {
|
SealedSenderMultiRecipientMessage multiRecipientMessage, boolean isStory) {
|
||||||
return Flux.fromArray(multiRecipientMessage.recipients())
|
return Flux.fromIterable(multiRecipientMessage.getRecipients().entrySet())
|
||||||
.groupBy(Recipient::uuid, multiRecipientMessage.recipients().length)
|
.map(e -> Tuples.of(ServiceIdentifier.fromLibsignal(e.getKey()), e.getValue()))
|
||||||
.flatMap(
|
.flatMap(
|
||||||
gf -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(gf.key()))
|
t -> Mono.justOrEmpty(accountsManager.getByServiceIdentifier(t.getT1()))
|
||||||
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
|
.switchIfEmpty(isStory ? Mono.empty() : Mono.error(NotFoundException::new))
|
||||||
.flatMap(
|
.map(
|
||||||
account ->
|
account ->
|
||||||
gf.collectMap(Recipient::deviceId)
|
new MultiRecipientDeliveryData(
|
||||||
.map(perRecipientData ->
|
t.getT1(),
|
||||||
new MultiRecipientDeliveryData(
|
account,
|
||||||
gf.key(),
|
t.getT2(),
|
||||||
account,
|
t.getT2().getDevicesAndRegistrationIds().collect(
|
||||||
perRecipientData))))
|
Collectors.toMap(Pair<Byte, Short>::first, Pair<Byte, Short>::second))))
|
||||||
|
// IllegalStateException is thrown by Collectors#toMap when we have multiple entries for the same device
|
||||||
|
.onErrorMap(e -> e instanceof IllegalStateException ? new BadRequestException() : e))
|
||||||
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
|
.collectMap(MultiRecipientDeliveryData::serviceIdentifier)
|
||||||
.block();
|
.block();
|
||||||
}
|
}
|
||||||
|
@ -429,8 +426,8 @@ public class MessageController {
|
||||||
|
|
||||||
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
|
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
|
||||||
@QueryParam("story") boolean isStory,
|
@QueryParam("story") boolean isStory,
|
||||||
@Parameter(description="The sealed-sender multi-recipient message payload")
|
@Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal")
|
||||||
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
|
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
|
||||||
|
|
||||||
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
|
final Map<ServiceIdentifier, MultiRecipientDeliveryData> recipients = buildRecipientMap(multiRecipientMessage, isStory);
|
||||||
|
|
||||||
|
@ -456,13 +453,13 @@ public class MessageController {
|
||||||
final Account account = recipient.account();
|
final Account account = recipient.account();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.perDeviceData().keySet(), Collections.emptySet());
|
DestinationDeviceValidator.validateCompleteDeviceList(account, recipient.deviceIdToRegistrationId().keySet(), Collections.emptySet());
|
||||||
|
|
||||||
DestinationDeviceValidator.validateRegistrationIds(
|
DestinationDeviceValidator.validateRegistrationIds(
|
||||||
account,
|
account,
|
||||||
recipient.perDeviceData().values(),
|
recipient.deviceIdToRegistrationId().entrySet(),
|
||||||
Recipient::deviceId,
|
Map.Entry<Byte, Short>::getKey,
|
||||||
Recipient::registrationId,
|
e -> Integer.valueOf(e.getValue()),
|
||||||
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
|
recipient.serviceIdentifier().identityType() == IdentityType.PNI);
|
||||||
} catch (MismatchedDevicesException e) {
|
} catch (MismatchedDevicesException e) {
|
||||||
accountMismatchedDevices.add(
|
accountMismatchedDevices.add(
|
||||||
|
@ -500,17 +497,19 @@ public class MessageController {
|
||||||
CompletableFuture.allOf(
|
CompletableFuture.allOf(
|
||||||
recipients.values().stream()
|
recipients.values().stream()
|
||||||
.flatMap(recipientData ->
|
.flatMap(recipientData ->
|
||||||
recipientData.perDeviceData().values().stream().map(
|
recipientData.deviceIdToRegistrationId().keySet().stream().map(
|
||||||
recipient -> CompletableFuture.runAsync(
|
deviceId ->CompletableFuture.runAsync(
|
||||||
() -> {
|
() -> {
|
||||||
final Account destinationAccount = recipientData.account();
|
final Account destinationAccount = recipientData.account();
|
||||||
|
final byte[] payload = multiRecipientMessage.messageForRecipient(recipientData.recipient());
|
||||||
|
|
||||||
// we asserted this must exist in validateCompleteDeviceList
|
// we asserted this must exist in validateCompleteDeviceList
|
||||||
final Device destinationDevice = destinationAccount.getDevice(recipient.deviceId()).orElseThrow();
|
final Device destinationDevice = destinationAccount.getDevice(deviceId).orElseThrow();
|
||||||
try {
|
try {
|
||||||
sentMessageCounter.increment();
|
sentMessageCounter.increment();
|
||||||
sendCommonPayloadMessage(
|
sendCommonPayloadMessage(
|
||||||
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
|
destinationAccount, destinationDevice, recipientData.serviceIdentifier(), timestamp, online,
|
||||||
isStory, isUrgent, recipient, multiRecipientMessage.commonPayload());
|
isStory, isUrgent, payload);
|
||||||
} catch (NoSuchUserException e) {
|
} catch (NoSuchUserException e) {
|
||||||
// this should never happen, because we already asserted the device is present and enabled
|
// this should never happen, because we already asserted the device is present and enabled
|
||||||
Metrics.counter(
|
Metrics.counter(
|
||||||
|
@ -740,17 +739,10 @@ public class MessageController {
|
||||||
boolean online,
|
boolean online,
|
||||||
boolean story,
|
boolean story,
|
||||||
boolean urgent,
|
boolean urgent,
|
||||||
Recipient recipient,
|
byte[] payload) throws NoSuchUserException {
|
||||||
byte[] commonPayload) throws NoSuchUserException {
|
|
||||||
try {
|
try {
|
||||||
Envelope.Builder messageBuilder = Envelope.newBuilder();
|
Envelope.Builder messageBuilder = Envelope.newBuilder();
|
||||||
long serverTimestamp = System.currentTimeMillis();
|
long serverTimestamp = System.currentTimeMillis();
|
||||||
byte[] recipientKeyMaterial = recipient.perRecipientKeyMaterial();
|
|
||||||
|
|
||||||
byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length];
|
|
||||||
payload[0] = MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER;
|
|
||||||
System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length);
|
|
||||||
System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, commonPayload.length);
|
|
||||||
|
|
||||||
messageBuilder
|
messageBuilder
|
||||||
.setType(Type.UNIDENTIFIED_SENDER)
|
.setType(Type.UNIDENTIFIED_SENDER)
|
||||||
|
|
|
@ -1,97 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2021 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Objects;
|
|
||||||
import javax.validation.Valid;
|
|
||||||
import javax.validation.constraints.AssertTrue;
|
|
||||||
import javax.validation.constraints.Max;
|
|
||||||
import javax.validation.constraints.Min;
|
|
||||||
import javax.validation.constraints.NotNull;
|
|
||||||
import javax.validation.constraints.Size;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
|
||||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
|
||||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
|
||||||
|
|
||||||
import io.micrometer.core.instrument.Counter;
|
|
||||||
import io.micrometer.core.instrument.Metrics;
|
|
||||||
import org.whispersystems.textsecuregcm.util.ServiceIdentifierAdapter;
|
|
||||||
|
|
||||||
public record MultiRecipientMessage(
|
|
||||||
@NotNull @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) @Valid Recipient[] recipients,
|
|
||||||
@NotNull @Size(min = 32) byte[] commonPayload) {
|
|
||||||
|
|
||||||
private static final Counter REJECT_DUPLICATE_RECIPIENT_COUNTER =
|
|
||||||
Metrics.counter(
|
|
||||||
name(MessageController.class, "rejectDuplicateRecipients"),
|
|
||||||
"multiRecipient", "false");
|
|
||||||
|
|
||||||
public record Recipient(@NotNull
|
|
||||||
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
|
|
||||||
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
|
|
||||||
ServiceIdentifier uuid,
|
|
||||||
@Min(1) byte deviceId,
|
|
||||||
@Min(0) @Max(65535) int registrationId,
|
|
||||||
@Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(final Object o) {
|
|
||||||
if (this == o)
|
|
||||||
return true;
|
|
||||||
if (o == null || getClass() != o.getClass())
|
|
||||||
return false;
|
|
||||||
Recipient recipient = (Recipient) o;
|
|
||||||
return deviceId == recipient.deviceId && registrationId == recipient.registrationId && uuid.equals(recipient.uuid)
|
|
||||||
&& Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
int result = Objects.hash(uuid, deviceId, registrationId);
|
|
||||||
result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) {
|
|
||||||
this.recipients = recipients;
|
|
||||||
this.commonPayload = commonPayload;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AssertTrue
|
|
||||||
public boolean hasNoDuplicateRecipients() {
|
|
||||||
boolean valid =
|
|
||||||
Arrays.stream(recipients).map(r -> new Pair<>(r.uuid(), r.deviceId())).distinct().count() == recipients.length;
|
|
||||||
if (!valid) {
|
|
||||||
REJECT_DUPLICATE_RECIPIENT_COUNTER.increment();
|
|
||||||
}
|
|
||||||
return valid;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(final Object o) {
|
|
||||||
if (this == o)
|
|
||||||
return true;
|
|
||||||
if (o == null || getClass() != o.getClass())
|
|
||||||
return false;
|
|
||||||
MultiRecipientMessage that = (MultiRecipientMessage) o;
|
|
||||||
return Arrays.equals(recipients, that.recipients) && Arrays.equals(commonPayload, that.commonPayload);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
int result = Arrays.hashCode(recipients);
|
|
||||||
result = 31 * result + Arrays.hashCode(commonPayload);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.identity;
|
||||||
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import org.signal.libsignal.protocol.ServiceId;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A "service identifier" is a tuple of a UUID and identity type that identifies an account and identity within the
|
* A "service identifier" is a tuple of a UUID and identity type that identifies an account and identity within the
|
||||||
|
@ -70,4 +71,14 @@ public interface ServiceIdentifier {
|
||||||
return PniServiceIdentifier.fromBytes(bytes);
|
return PniServiceIdentifier.fromBytes(bytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ServiceIdentifier fromLibsignal(final ServiceId libsignalServiceId) {
|
||||||
|
if (libsignalServiceId instanceof ServiceId.Aci) {
|
||||||
|
return new AciServiceIdentifier(libsignalServiceId.getRawUUID());
|
||||||
|
}
|
||||||
|
if (libsignalServiceId instanceof ServiceId.Pni) {
|
||||||
|
return new PniServiceIdentifier(libsignalServiceId.getRawUUID());
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("unknown libsignal ServiceId type");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.providers;
|
package org.whispersystems.textsecuregcm.providers;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
|
||||||
import io.dropwizard.util.DataSizeUnit;
|
import io.dropwizard.util.DataSizeUnit;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
@ -19,150 +18,36 @@ import javax.ws.rs.core.MultivaluedMap;
|
||||||
import javax.ws.rs.core.NoContentException;
|
import javax.ws.rs.core.NoContentException;
|
||||||
import javax.ws.rs.ext.MessageBodyReader;
|
import javax.ws.rs.ext.MessageBodyReader;
|
||||||
import javax.ws.rs.ext.Provider;
|
import javax.ws.rs.ext.Provider;
|
||||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||||
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
|
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||||
|
import org.signal.libsignal.protocol.InvalidVersionException;
|
||||||
|
|
||||||
@Provider
|
@Provider
|
||||||
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
|
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
|
||||||
public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRecipientMessage> {
|
public class MultiRecipientMessageProvider implements MessageBodyReader<SealedSenderMultiRecipientMessage> {
|
||||||
|
|
||||||
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
|
public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
|
||||||
public static final int MAX_RECIPIENT_COUNT = 5000;
|
public static final int MAX_RECIPIENT_COUNT = 5000;
|
||||||
public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256));
|
public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256));
|
||||||
|
|
||||||
public static final byte AMBIGUOUS_ID_VERSION_IDENTIFIER = 0x22;
|
|
||||||
public static final byte EXPLICIT_ID_VERSION_IDENTIFIER = 0x23;
|
|
||||||
|
|
||||||
private enum Version {
|
|
||||||
AMBIGUOUS_ID(AMBIGUOUS_ID_VERSION_IDENTIFIER),
|
|
||||||
EXPLICIT_ID(EXPLICIT_ID_VERSION_IDENTIFIER);
|
|
||||||
|
|
||||||
private final byte identifier;
|
|
||||||
|
|
||||||
Version(final byte identifier) {
|
|
||||||
this.identifier = identifier;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Version forVersionByte(final byte versionByte) {
|
|
||||||
for (final Version version : values()) {
|
|
||||||
if (version.identifier == versionByte) {
|
|
||||||
return version;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw new IllegalArgumentException("Unrecognized version byte: " + versionByte);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
|
public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
|
||||||
return MEDIA_TYPE.equals(mediaType.toString()) && MultiRecipientMessage.class.isAssignableFrom(type);
|
return MEDIA_TYPE.equals(mediaType.toString()) && SealedSenderMultiRecipientMessage.class.isAssignableFrom(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MultiRecipientMessage readFrom(Class<MultiRecipientMessage> type, Type genericType, Annotation[] annotations,
|
public SealedSenderMultiRecipientMessage readFrom(Class<SealedSenderMultiRecipientMessage> type, Type genericType, Annotation[] annotations,
|
||||||
MediaType mediaType, MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
|
MediaType mediaType, MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
|
||||||
throws IOException, WebApplicationException {
|
throws IOException, WebApplicationException {
|
||||||
int versionByte = entityStream.read();
|
byte[] fullMessage = entityStream.readNBytes(MAX_MESSAGE_SIZE + MAX_RECIPIENT_COUNT * 100);
|
||||||
if (versionByte == -1) {
|
if (fullMessage.length == 0) {
|
||||||
throw new NoContentException("Empty body not allowed");
|
throw new NoContentException("Empty body not allowed");
|
||||||
}
|
}
|
||||||
|
|
||||||
final Version version;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
version = Version.forVersionByte((byte) versionByte);
|
return SealedSenderMultiRecipientMessage.parse(fullMessage);
|
||||||
} catch (final IllegalArgumentException e) {
|
} catch (InvalidMessageException | InvalidVersionException e) {
|
||||||
throw new BadRequestException("Unsupported version");
|
throw new BadRequestException(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
long count = readVarint(entityStream);
|
|
||||||
if (count > MAX_RECIPIENT_COUNT) {
|
|
||||||
throw new BadRequestException("Maximum recipient count exceeded");
|
|
||||||
}
|
|
||||||
MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)];
|
|
||||||
for (int i = 0; i < Math.toIntExact(count); i++) {
|
|
||||||
ServiceIdentifier identifier = readIdentifier(entityStream, version);
|
|
||||||
final byte deviceId;
|
|
||||||
{
|
|
||||||
long deviceIdLong = readVarint(entityStream);
|
|
||||||
if (deviceIdLong > Byte.MAX_VALUE) {
|
|
||||||
throw new BadRequestException("Invalid device ID");
|
|
||||||
}
|
|
||||||
deviceId = (byte) deviceIdLong;
|
|
||||||
}
|
|
||||||
int registrationId = readU16(entityStream);
|
|
||||||
byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
|
|
||||||
if (perRecipientKeyMaterial.length != 48) {
|
|
||||||
throw new IOException("Failed to read expected number of key material bytes for a recipient");
|
|
||||||
}
|
|
||||||
recipients[i] = new MultiRecipientMessage.Recipient(identifier, deviceId, registrationId, perRecipientKeyMaterial);
|
|
||||||
}
|
|
||||||
|
|
||||||
// caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than
|
|
||||||
// this it'll return an error back. We just need to limit how many we'll accept here.
|
|
||||||
byte[] commonPayload = entityStream.readNBytes(MAX_MESSAGE_SIZE);
|
|
||||||
if (commonPayload.length < 32) {
|
|
||||||
throw new IOException("Failed to read expected number of common key material bytes");
|
|
||||||
}
|
|
||||||
return new MultiRecipientMessage(recipients, commonPayload);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Reads a service identifier from the given stream.
|
|
||||||
*/
|
|
||||||
private ServiceIdentifier readIdentifier(final InputStream stream, final Version version) throws IOException {
|
|
||||||
final byte[] uuidBytes = switch (version) {
|
|
||||||
case AMBIGUOUS_ID -> stream.readNBytes(16);
|
|
||||||
case EXPLICIT_ID -> stream.readNBytes(17);
|
|
||||||
};
|
|
||||||
|
|
||||||
return ServiceIdentifier.fromBytes(uuidBytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
|
|
||||||
* {@code IOException} is thrown if the stream ends before we finish reading the varint.
|
|
||||||
*
|
|
||||||
* @return the varint value
|
|
||||||
*/
|
|
||||||
@VisibleForTesting
|
|
||||||
public static long readVarint(InputStream stream) throws IOException, WebApplicationException {
|
|
||||||
boolean hasMore = true;
|
|
||||||
int currentOffset = 0;
|
|
||||||
long result = 0;
|
|
||||||
while (hasMore) {
|
|
||||||
if (currentOffset >= 64) {
|
|
||||||
throw new BadRequestException("varint is too large");
|
|
||||||
}
|
|
||||||
int b = stream.read();
|
|
||||||
if (b == -1) {
|
|
||||||
throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
|
|
||||||
}
|
|
||||||
if (currentOffset == 63 && (b & 0xFE) != 0) {
|
|
||||||
throw new BadRequestException("varint is too large");
|
|
||||||
}
|
|
||||||
hasMore = (b & 0x80) != 0;
|
|
||||||
result |= (b & 0x7FL) << currentOffset;
|
|
||||||
currentOffset += 7;
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Reads two bytes with most significant byte first. Treats the value as unsigned so the range returned is
|
|
||||||
* {@code [0, 65535]}.
|
|
||||||
*/
|
|
||||||
@VisibleForTesting
|
|
||||||
static int readU16(InputStream stream) throws IOException {
|
|
||||||
int b1 = stream.read();
|
|
||||||
if (b1 == -1) {
|
|
||||||
throw new IOException("Missing byte 1 of U16");
|
|
||||||
}
|
|
||||||
int b2 = stream.read();
|
|
||||||
if (b2 == -1) {
|
|
||||||
throw new IOException("Missing byte 2 of U16");
|
|
||||||
}
|
|
||||||
return (b1 << 8) | b2;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,32 +4,14 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.util;
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
import static com.google.common.base.Objects.equal;
|
import java.util.Map;
|
||||||
|
|
||||||
public class Pair<T1, T2> {
|
public record Pair<T1, T2>(T1 first, T2 second) {
|
||||||
private final T1 v1;
|
public Pair(org.signal.libsignal.protocol.util.Pair<T1, T2> p) {
|
||||||
private final T2 v2;
|
this(p.first(), p.second());
|
||||||
|
|
||||||
public Pair(T1 v1, T2 v2) {
|
|
||||||
this.v1 = v1;
|
|
||||||
this.v2 = v2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public T1 first() {
|
public Pair(Map.Entry<T1, T2> e) {
|
||||||
return v1;
|
this(e.getKey(), e.getValue());
|
||||||
}
|
|
||||||
|
|
||||||
public T2 second() {
|
|
||||||
return v2;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
return o instanceof Pair &&
|
|
||||||
equal(((Pair<?, ?>) o).first(), first()) &&
|
|
||||||
equal(((Pair<?, ?>) o).second(), second());
|
|
||||||
}
|
|
||||||
|
|
||||||
public int hashCode() {
|
|
||||||
return first().hashCode() ^ second().hashCode();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ package org.whispersystems.textsecuregcm.controllers;
|
||||||
import static org.hamcrest.CoreMatchers.equalTo;
|
import static org.hamcrest.CoreMatchers.equalTo;
|
||||||
import static org.hamcrest.CoreMatchers.is;
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
import static org.hamcrest.CoreMatchers.not;
|
import static org.hamcrest.CoreMatchers.not;
|
||||||
import static org.hamcrest.collection.IsEmptyCollection.empty;
|
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
|
import static org.hamcrest.collection.IsEmptyCollection.empty;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
|
@ -19,7 +19,6 @@ import static org.mockito.ArgumentMatchers.argThat;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.anyBoolean;
|
import static org.mockito.Mockito.anyBoolean;
|
||||||
import static org.mockito.Mockito.anyString;
|
import static org.mockito.Mockito.anyString;
|
||||||
import static org.mockito.Mockito.atLeastOnce;
|
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
|
@ -55,7 +54,6 @@ import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
@ -78,7 +76,6 @@ import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.ArgumentsSources;
|
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
|
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
|
||||||
|
@ -981,9 +978,7 @@ class MessageControllerTest {
|
||||||
bb.order(ByteOrder.BIG_ENDIAN);
|
bb.order(ByteOrder.BIG_ENDIAN);
|
||||||
|
|
||||||
// first write the header
|
// first write the header
|
||||||
bb.put(explicitIdentifiers
|
bb.put(explicitIdentifiers ? (byte) 0x23 : (byte) 0x22); // version byte
|
||||||
? MultiRecipientMessageProvider.EXPLICIT_ID_VERSION_IDENTIFIER
|
|
||||||
: MultiRecipientMessageProvider.AMBIGUOUS_ID_VERSION_IDENTIFIER); // version byte
|
|
||||||
|
|
||||||
// count varint
|
// count varint
|
||||||
int nRecip = recipients.size();
|
int nRecip = recipients.size();
|
||||||
|
@ -1258,7 +1253,7 @@ class MessageControllerTest {
|
||||||
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
|
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
|
||||||
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
|
.put(Entity.entity(initializeMultiPayload(recipients, new byte[2048], useExplicitIdentifier), MultiRecipientMessageProvider.MEDIA_TYPE));
|
||||||
|
|
||||||
checkBadMultiRecipientResponse(response, 422);
|
checkBadMultiRecipientResponse(response, 400);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2021 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.providers;
|
|
||||||
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
|
||||||
import static org.junit.jupiter.params.provider.Arguments.arguments;
|
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
|
||||||
import java.util.stream.Stream;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
|
||||||
|
|
||||||
public class MultiRecipientMessageProviderTest {
|
|
||||||
|
|
||||||
static byte[] createTwoByteArray(int b1, int b2) {
|
|
||||||
return new byte[]{(byte) b1, (byte) b2};
|
|
||||||
}
|
|
||||||
|
|
||||||
static Stream<Arguments> readU16TestCases() {
|
|
||||||
return Stream.of(
|
|
||||||
arguments(0xFFFE, createTwoByteArray(0xFF, 0xFE)),
|
|
||||||
arguments(0x0001, createTwoByteArray(0x00, 0x01)),
|
|
||||||
arguments(0xBEEF, createTwoByteArray(0xBE, 0xEF)),
|
|
||||||
arguments(0xFFFF, createTwoByteArray(0xFF, 0xFF)),
|
|
||||||
arguments(0x0000, createTwoByteArray(0x00, 0x00)),
|
|
||||||
arguments(0xF080, createTwoByteArray(0xF0, 0x80))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
@ParameterizedTest
|
|
||||||
@MethodSource("readU16TestCases")
|
|
||||||
void testReadU16(int expectedValue, byte[] input) throws Exception {
|
|
||||||
try (final ByteArrayInputStream stream = new ByteArrayInputStream(input)) {
|
|
||||||
assertThat(MultiRecipientMessageProvider.readU16(stream)).isEqualTo(expectedValue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue