From c448c37cc9fbd938c500c4fc11bb35caf4684b52 Mon Sep 17 00:00:00 2001 From: Ehren Kret Date: Fri, 23 Apr 2021 16:10:17 -0500 Subject: [PATCH] Add logic to handle sending a common payload to multiple recipients --- .../textsecuregcm/WhisperServerService.java | 5 +- .../CombinedUnidentifiedSenderAccessKeys.java | 31 ++++ .../controllers/MessageController.java | 158 ++++++++++++++++-- .../entities/MultiRecipientMessage.java | 67 ++++++++ .../entities/SendMessageResponse.java | 8 + .../MultiRecipientMessageProvider.java | 129 ++++++++++++++ 6 files changed, 385 insertions(+), 13 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/CombinedUnidentifiedSenderAccessKeys.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index b9a2dc379..75d214cd6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -117,6 +117,7 @@ import org.whispersystems.textsecuregcm.metrics.NstatCounters; import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.TrafficSource; +import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck; import org.whispersystems.textsecuregcm.push.APNSender; @@ -477,14 +478,12 @@ public class WhisperServerService extends Application(ImmutableMap.of(Account.class, accountAuthFilter, DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter))); environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))); - environment.jersey().register(new TimestampResponseFilter()); - environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, usernamesManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator, verifyExperimentEnrollmentManager)); environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices())); environment.jersey().register(new DirectoryController(directoryCredentialsGenerator)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/CombinedUnidentifiedSenderAccessKeys.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CombinedUnidentifiedSenderAccessKeys.java new file mode 100644 index 000000000..9466692b7 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CombinedUnidentifiedSenderAccessKeys.java @@ -0,0 +1,31 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.io.IOException; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.Status; +import org.whispersystems.textsecuregcm.util.Base64; + +public class CombinedUnidentifiedSenderAccessKeys { + private final byte[] combinedUnidentifiedSenderAccessKeys; + + public CombinedUnidentifiedSenderAccessKeys(String header) { + try { + this.combinedUnidentifiedSenderAccessKeys = Base64.decode(header); + if (this.combinedUnidentifiedSenderAccessKeys == null || this.combinedUnidentifiedSenderAccessKeys.length != 16) { + throw new WebApplicationException("Invalid combined unidentified sender access keys", Status.UNAUTHORIZED); + } + } catch (IOException e) { + throw new WebApplicationException(e, Response.Status.UNAUTHORIZED); + } + } + + public byte[] getAccessKeys() { + return combinedUnidentifiedSenderAccessKeys; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 7d2655ab8..dab3327e2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -20,17 +20,24 @@ import io.lettuce.core.ScriptOutputType; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import java.io.IOException; +import java.security.MessageDigest; import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.stream.Collectors; import javax.validation.Valid; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; @@ -40,26 +47,33 @@ import javax.ws.rs.PUT; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.Status; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.Anonymous; +import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; +import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; @@ -89,6 +103,7 @@ public class MessageController { private final Meter identifiedMeter = metricRegistry.meter(name(getClass(), "delivery", "identified" )); private final Meter rejectOver256kibMessageMeter = metricRegistry.meter(name(getClass(), "rejectOver256kibMessage")); private final Timer sendMessageInternalTimer = metricRegistry.timer(name(getClass(), "sendMessageInternal")); + private final Timer sendCommonMessageInternalTimer = metricRegistry.timer(name(getClass(), "sendCommonMessageInternal")); private final Histogram outgoingMessageListSizeHistogram = metricRegistry.histogram(name(getClass(), "outgoingMessageListSize")); private final RateLimiters rateLimiters; @@ -295,6 +310,99 @@ public class MessageController { } } + @Timed + @Path("/multi_recipient") + @PUT + @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) + @Produces(MediaType.APPLICATION_JSON) + public Response sendMultiRecipientMessage( + @HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys, + @HeaderParam("User-Agent") String userAgent, + @HeaderParam("X-Forwarded-For") String forwardedFor, + @QueryParam("online") boolean online, + @QueryParam("ts") long timestamp, + @Valid MultiRecipientMessage multiRecipientMessage) { + + unidentifiedMeter.mark(multiRecipientMessage.getRecipients().length); + + Map uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) + .map(Recipient::getUuid) + .distinct() + .collect(Collectors.toMap(Function.identity(), uuid -> { + Optional account = accountsManager.get(uuid); + if (account.isEmpty()) { + throw new WebApplicationException(Status.NOT_FOUND); + } + return account.get(); + })); + checkAccessKeys(accessKeys, uuidToAccountMap); + + try { + for (Account account : uuidToAccountMap.values()) { + Set deviceIds = Arrays.stream(multiRecipientMessage.getRecipients()) + .filter(recipient -> recipient.getUuid().equals(account.getUuid())) + .map(Recipient::getDeviceId) + .collect(Collectors.toSet()); + validateCompleteDeviceList(account, deviceIds, false); + } + + List tags = List.of( + UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)), + Tag.of(SENDER_TYPE_TAG_NAME, "unidentified")); + List uuids404 = new ArrayList<>(); + for (Recipient recipient : multiRecipientMessage.getRecipients()) { + + Account destinationAccount = uuidToAccountMap.get(recipient.getUuid()); + // we asserted this must be true in validateCompleteDeviceList + //noinspection OptionalGetWithoutIsPresent + Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).get(); + Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); + try { + sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, + multiRecipientMessage.getCommonPayload()); + } catch (NoSuchUserException e) { + uuids404.add(destinationAccount.getUuid()); + } + } + return Response.ok(new SendMessageResponse(uuids404)).build(); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response + .status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices())) + .build()); + } + } + + private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map uuidToAccountMap) { + AtomicBoolean throwUnauthorized = new AtomicBoolean(false); + byte[] empty = new byte[16]; + byte[] combinedUnknownAccessKeys = uuidToAccountMap.values().stream() + .map(Account::getUnidentifiedAccessKey) + .map(accessKey -> { + if (accessKey.isEmpty()) { + throwUnauthorized.set(true); + return empty; + } + return accessKey.get(); + }) + .reduce(new byte[16], (bytes, bytes2) -> { + if (bytes.length != bytes2.length) { + throwUnauthorized.set(true); + return bytes; + } + for (int i = 0; i < bytes.length; i++) { + bytes[i] ^= bytes2[i]; + } + return bytes; + }); + if (throwUnauthorized.get() + || !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) { + throw new WebApplicationException(Status.UNAUTHORIZED); + } + } + private Response declineDelivery(final IncomingMessageList messages, final Account source, final Account destination) { Metrics.counter(DECLINED_DELIVERY_COUNTER, SENDER_COUNTRY_TAG_NAME, Util.getCountryCode(source.getNumber())).increment(); @@ -464,6 +572,34 @@ public class MessageController { } } + private void sendMessage(Account destinationAccount, Device destinationDevice, long timestamp, boolean online, + Recipient recipient, byte[] commonPayload) throws NoSuchUserException { + try (final Timer.Context ignored = sendCommonMessageInternalTimer.time()) { + Envelope.Builder messageBuilder = Envelope.newBuilder(); + long serverTimestamp = System.currentTimeMillis(); + byte[] recipientKeyMaterial = recipient.getPerRecipientKeyMaterial(); + + byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length]; + payload[0] = MultiRecipientMessageProvider.VERSION; + System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length); + System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, payload.length); + + messageBuilder + .setType(Type.UNIDENTIFIED_SENDER) + .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) + .setServerTimestamp(serverTimestamp) + .setContent(ByteString.copyFrom(payload)); + + messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); + } catch (NotPushRegisteredException e) { + if (destinationDevice.isMaster()) { + throw new NoSuchUserException(e); + } else { + logger.debug("Not registered", e); + } + } + } + private void validateRegistrationIds(Account account, List messages) throws StaleDevicesException { @@ -485,22 +621,24 @@ public class MessageController { } } + private void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage) + throws MismatchedDevicesException { + Set messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()); + validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage); + } + + private void validateCompleteDeviceList(Account account, - List messages, + Set messageDeviceIds, boolean isSyncMessage) throws MismatchedDevicesException { - Set messageDeviceIds = new HashSet<>(); Set accountDeviceIds = new HashSet<>(); List missingDeviceIds = new LinkedList<>(); List extraDeviceIds = new LinkedList<>(); - for (IncomingMessage message : messages) { - messageDeviceIds.add(message.getDestinationDeviceId()); - } - - for (Device device : account.getDevices()) { + for (Device device : account.getDevices()) { if (device.isEnabled() && !(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId())) { @@ -512,9 +650,9 @@ public class MessageController { } } - for (IncomingMessage message : messages) { - if (!accountDeviceIds.contains(message.getDestinationDeviceId())) { - extraDeviceIds.add(message.getDestinationDeviceId()); + for (Long deviceId : messageDeviceIds) { + if (!accountDeviceIds.contains(deviceId)) { + extraDeviceIds.add(deviceId); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java new file mode 100644 index 000000000..73ac76104 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import java.util.UUID; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Size; +import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; + +public class MultiRecipientMessage { + + public static class Recipient { + + @NotNull + private final UUID uuid; + + @Min(1) + private final long deviceId; + + @Size(min = 48, max = 48) + @NotNull + private final byte[] perRecipientKeyMaterial; + + public Recipient(UUID uuid, long deviceId, byte[] perRecipientKeyMaterial) { + this.uuid = uuid; + this.deviceId = deviceId; + this.perRecipientKeyMaterial = perRecipientKeyMaterial; + } + + public UUID getUuid() { + return uuid; + } + + public long getDeviceId() { + return deviceId; + } + + public byte[] getPerRecipientKeyMaterial() { + return perRecipientKeyMaterial; + } + } + + @NotNull + @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT) + private final Recipient[] recipients; + + @NotNull + @Min(32) + private final byte[] commonPayload; + + public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) { + this.recipients = recipients; + this.commonPayload = commonPayload; + } + + public Recipient[] getRecipients() { + return recipients; + } + + public byte[] getCommonPayload() { + return commonPayload; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMessageResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMessageResponse.java index f578fb1d6..025441c57 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMessageResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMessageResponse.java @@ -6,16 +6,24 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; +import java.util.UUID; public class SendMessageResponse { @JsonProperty private boolean needsSync; + @JsonProperty + private List uuids404; + public SendMessageResponse() {} public SendMessageResponse(boolean needsSync) { this.needsSync = needsSync; } + public SendMessageResponse(List uuids404) { + this.uuids404 = uuids404; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java new file mode 100644 index 000000000..2238253c0 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -0,0 +1,129 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.providers; + +import io.dropwizard.util.DataSizeUnit; +import java.io.IOException; +import java.io.InputStream; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; +import java.util.UUID; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.NoContentException; +import javax.ws.rs.ext.MessageBodyReader; +import javax.ws.rs.ext.Provider; +import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage; + +@Provider +@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE) +public class MultiRecipientMessageProvider implements MessageBodyReader { + + public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm"; + public static final int MAX_RECIPIENT_COUNT = 5000; + public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256)); + public static final byte VERSION = 0x22; + + @Override + public boolean isReadable(Class type, Type genericType, Annotation[] annotations, MediaType mediaType) { + return MEDIA_TYPE.equals(mediaType.toString()) && MultiRecipientMessage.class.isAssignableFrom(type); + } + + @Override + public MultiRecipientMessage readFrom(Class type, Type genericType, Annotation[] annotations, + MediaType mediaType, MultivaluedMap httpHeaders, InputStream entityStream) + throws IOException, WebApplicationException { + int versionByte = entityStream.read(); + if (versionByte == -1) { + throw new NoContentException("Empty body not allowed"); + } + if (versionByte != VERSION) { + throw new BadRequestException("Unsupported version"); + } + long count = readVarint(entityStream); + if (count > MAX_RECIPIENT_COUNT) { + throw new BadRequestException("Maximum recipient count exceeded"); + } + MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)]; + for (int i = 0; i < Math.toIntExact(count); i++) { + UUID uuid = readUuid(entityStream); + long deviceId = readVarint(entityStream); + byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); + if (perRecipientKeyMaterial.length != 48) { + throw new IOException("Failed to read expected number of key material bytes for a recipient"); + } + recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, perRecipientKeyMaterial); + } + + // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than + // this it'll return an error back. We just need to limit how many we'll accept here. + byte[] commonPayload = entityStream.readNBytes(MAX_MESSAGE_SIZE); + if (commonPayload.length < 32) { + throw new IOException("Failed to read expected number of common key material bytes"); + } + return new MultiRecipientMessage(recipients, commonPayload); + } + + /** + * Reads a UUID in network byte order and converts to a UUID object. + */ + private UUID readUuid(InputStream stream) throws IOException { + byte[] buffer = new byte[8]; + + int read = stream.read(buffer); + if (read != 8) { + throw new IOException("Insufficient bytes for UUID"); + } + long msb = convertNetworkByteOrderToLong(buffer); + + read = stream.read(buffer); + if (read != 8) { + throw new IOException("Insufficient bytes for UUID"); + } + long lsb = convertNetworkByteOrderToLong(buffer); + + return new UUID(msb, lsb); + } + + private long convertNetworkByteOrderToLong(byte[] buffer) { + long result = 0; + for (int i = 0; i < 8; i++) { + result = (result << (i * 8)) | (buffer[i] & 0xFFL); + } + return result; + } + + /** + * Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An + * {@code IOException} is thrown if the stream ends before we finish reading the varint. + * + * @return the varint value + */ + private long readVarint(InputStream stream) throws IOException, WebApplicationException { + boolean hasMore = true; + int currentOffset = 0; + int result = 0; + while (hasMore) { + if (currentOffset >= 64) { + throw new BadRequestException("varint is too large"); + } + int b = stream.read(); + if (b == -1) { + throw new IOException("Missing byte " + (currentOffset / 7) + " of varint"); + } + if (currentOffset == 63 && (b & 0xFE) != 0) { + throw new BadRequestException("varint is too large"); + } + hasMore = (b & 0x80) != 0; + result |= (b & 0x7F) << currentOffset; + currentOffset += 7; + } + return result; + } +}