Add logic to handle sending a common payload to multiple recipients
This commit is contained in:
		
							parent
							
								
									f117d9ff4d
								
							
						
					
					
						commit
						c448c37cc9
					
				| 
						 | 
				
			
			@ -117,6 +117,7 @@ import org.whispersystems.textsecuregcm.metrics.NstatCounters;
 | 
			
		|||
import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
 | 
			
		||||
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
 | 
			
		||||
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
 | 
			
		||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
 | 
			
		||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
 | 
			
		||||
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
 | 
			
		||||
import org.whispersystems.textsecuregcm.push.APNSender;
 | 
			
		||||
| 
						 | 
				
			
			@ -477,14 +478,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
 | 
			
		|||
    environment.servlets().addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager))
 | 
			
		||||
        .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
 | 
			
		||||
 | 
			
		||||
    environment.jersey().register(MultiRecipientMessageProvider.class);
 | 
			
		||||
    environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP));
 | 
			
		||||
 | 
			
		||||
    environment.jersey().register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter,
 | 
			
		||||
                                                                                      DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)));
 | 
			
		||||
    environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)));
 | 
			
		||||
 | 
			
		||||
    environment.jersey().register(new TimestampResponseFilter());
 | 
			
		||||
 | 
			
		||||
    environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, usernamesManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator, verifyExperimentEnrollmentManager));
 | 
			
		||||
    environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices()));
 | 
			
		||||
    environment.jersey().register(new DirectoryController(directoryCredentialsGenerator));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,31 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2021 Signal Messenger, LLC
 | 
			
		||||
 * SPDX-License-Identifier: AGPL-3.0-only
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.whispersystems.textsecuregcm.auth;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import javax.ws.rs.WebApplicationException;
 | 
			
		||||
import javax.ws.rs.core.Response;
 | 
			
		||||
import javax.ws.rs.core.Response.Status;
 | 
			
		||||
import org.whispersystems.textsecuregcm.util.Base64;
 | 
			
		||||
 | 
			
		||||
public class CombinedUnidentifiedSenderAccessKeys {
 | 
			
		||||
  private final byte[] combinedUnidentifiedSenderAccessKeys;
 | 
			
		||||
 | 
			
		||||
  public CombinedUnidentifiedSenderAccessKeys(String header) {
 | 
			
		||||
    try {
 | 
			
		||||
      this.combinedUnidentifiedSenderAccessKeys = Base64.decode(header);
 | 
			
		||||
      if (this.combinedUnidentifiedSenderAccessKeys == null || this.combinedUnidentifiedSenderAccessKeys.length != 16) {
 | 
			
		||||
        throw new WebApplicationException("Invalid combined unidentified sender access keys", Status.UNAUTHORIZED);
 | 
			
		||||
      }
 | 
			
		||||
    } catch (IOException e) {
 | 
			
		||||
      throw new WebApplicationException(e, Response.Status.UNAUTHORIZED);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public byte[] getAccessKeys() {
 | 
			
		||||
    return combinedUnidentifiedSenderAccessKeys;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -20,17 +20,24 @@ import io.lettuce.core.ScriptOutputType;
 | 
			
		|||
import io.micrometer.core.instrument.Metrics;
 | 
			
		||||
import io.micrometer.core.instrument.Tag;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.security.MessageDigest;
 | 
			
		||||
import java.time.Duration;
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.Arrays;
 | 
			
		||||
import java.util.Base64;
 | 
			
		||||
import java.util.HashSet;
 | 
			
		||||
import java.util.LinkedList;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
import java.util.Optional;
 | 
			
		||||
import java.util.Random;
 | 
			
		||||
import java.util.Set;
 | 
			
		||||
import java.util.UUID;
 | 
			
		||||
import java.util.concurrent.ScheduledExecutorService;
 | 
			
		||||
import java.util.concurrent.TimeUnit;
 | 
			
		||||
import java.util.concurrent.atomic.AtomicBoolean;
 | 
			
		||||
import java.util.function.Function;
 | 
			
		||||
import java.util.stream.Collectors;
 | 
			
		||||
import javax.validation.Valid;
 | 
			
		||||
import javax.ws.rs.Consumes;
 | 
			
		||||
import javax.ws.rs.DELETE;
 | 
			
		||||
| 
						 | 
				
			
			@ -40,26 +47,33 @@ import javax.ws.rs.PUT;
 | 
			
		|||
import javax.ws.rs.Path;
 | 
			
		||||
import javax.ws.rs.PathParam;
 | 
			
		||||
import javax.ws.rs.Produces;
 | 
			
		||||
import javax.ws.rs.QueryParam;
 | 
			
		||||
import javax.ws.rs.WebApplicationException;
 | 
			
		||||
import javax.ws.rs.core.MediaType;
 | 
			
		||||
import javax.ws.rs.core.Response;
 | 
			
		||||
import javax.ws.rs.core.Response.Status;
 | 
			
		||||
import org.apache.commons.lang3.StringUtils;
 | 
			
		||||
import org.slf4j.Logger;
 | 
			
		||||
import org.slf4j.LoggerFactory;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.Anonymous;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
 | 
			
		||||
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
 | 
			
		||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope.Type;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.StaleDevices;
 | 
			
		||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
 | 
			
		||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
 | 
			
		||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
 | 
			
		||||
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
 | 
			
		||||
import org.whispersystems.textsecuregcm.push.MessageSender;
 | 
			
		||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
 | 
			
		||||
| 
						 | 
				
			
			@ -89,6 +103,7 @@ public class MessageController {
 | 
			
		|||
  private final Meter          identifiedMeter                  = metricRegistry.meter(name(getClass(), "delivery", "identified"  ));
 | 
			
		||||
  private final Meter          rejectOver256kibMessageMeter     = metricRegistry.meter(name(getClass(), "rejectOver256kibMessage"));
 | 
			
		||||
  private final Timer          sendMessageInternalTimer         = metricRegistry.timer(name(getClass(), "sendMessageInternal"));
 | 
			
		||||
  private final Timer          sendCommonMessageInternalTimer   = metricRegistry.timer(name(getClass(), "sendCommonMessageInternal"));
 | 
			
		||||
  private final Histogram      outgoingMessageListSizeHistogram = metricRegistry.histogram(name(getClass(), "outgoingMessageListSize"));
 | 
			
		||||
 | 
			
		||||
  private final RateLimiters                rateLimiters;
 | 
			
		||||
| 
						 | 
				
			
			@ -295,6 +310,99 @@ public class MessageController {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Timed
 | 
			
		||||
  @Path("/multi_recipient")
 | 
			
		||||
  @PUT
 | 
			
		||||
  @Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
 | 
			
		||||
  @Produces(MediaType.APPLICATION_JSON)
 | 
			
		||||
  public Response sendMultiRecipientMessage(
 | 
			
		||||
      @HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys,
 | 
			
		||||
      @HeaderParam("User-Agent") String userAgent,
 | 
			
		||||
      @HeaderParam("X-Forwarded-For") String forwardedFor,
 | 
			
		||||
      @QueryParam("online") boolean online,
 | 
			
		||||
      @QueryParam("ts") long timestamp,
 | 
			
		||||
      @Valid MultiRecipientMessage multiRecipientMessage) {
 | 
			
		||||
 | 
			
		||||
    unidentifiedMeter.mark(multiRecipientMessage.getRecipients().length);
 | 
			
		||||
 | 
			
		||||
    Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
 | 
			
		||||
        .map(Recipient::getUuid)
 | 
			
		||||
        .distinct()
 | 
			
		||||
        .collect(Collectors.toMap(Function.identity(), uuid -> {
 | 
			
		||||
          Optional<Account> account = accountsManager.get(uuid);
 | 
			
		||||
          if (account.isEmpty()) {
 | 
			
		||||
            throw new WebApplicationException(Status.NOT_FOUND);
 | 
			
		||||
          }
 | 
			
		||||
          return account.get();
 | 
			
		||||
        }));
 | 
			
		||||
    checkAccessKeys(accessKeys, uuidToAccountMap);
 | 
			
		||||
 | 
			
		||||
    try {
 | 
			
		||||
      for (Account account : uuidToAccountMap.values()) {
 | 
			
		||||
        Set<Long> deviceIds = Arrays.stream(multiRecipientMessage.getRecipients())
 | 
			
		||||
            .filter(recipient -> recipient.getUuid().equals(account.getUuid()))
 | 
			
		||||
            .map(Recipient::getDeviceId)
 | 
			
		||||
            .collect(Collectors.toSet());
 | 
			
		||||
        validateCompleteDeviceList(account, deviceIds, false);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      List<Tag> tags = List.of(
 | 
			
		||||
          UserAgentTagUtil.getPlatformTag(userAgent),
 | 
			
		||||
          Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(online)),
 | 
			
		||||
          Tag.of(SENDER_TYPE_TAG_NAME, "unidentified"));
 | 
			
		||||
      List<UUID> uuids404 = new ArrayList<>();
 | 
			
		||||
      for (Recipient recipient : multiRecipientMessage.getRecipients()) {
 | 
			
		||||
 | 
			
		||||
        Account destinationAccount = uuidToAccountMap.get(recipient.getUuid());
 | 
			
		||||
        // we asserted this must be true in validateCompleteDeviceList
 | 
			
		||||
        //noinspection OptionalGetWithoutIsPresent
 | 
			
		||||
        Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).get();
 | 
			
		||||
        Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
 | 
			
		||||
        try {
 | 
			
		||||
          sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient,
 | 
			
		||||
              multiRecipientMessage.getCommonPayload());
 | 
			
		||||
        } catch (NoSuchUserException e) {
 | 
			
		||||
          uuids404.add(destinationAccount.getUuid());
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      return Response.ok(new SendMessageResponse(uuids404)).build();
 | 
			
		||||
    } catch (MismatchedDevicesException e) {
 | 
			
		||||
      throw new WebApplicationException(Response
 | 
			
		||||
          .status(409)
 | 
			
		||||
          .type(MediaType.APPLICATION_JSON_TYPE)
 | 
			
		||||
          .entity(new MismatchedDevices(e.getMissingDevices(), e.getExtraDevices()))
 | 
			
		||||
          .build());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) {
 | 
			
		||||
    AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
 | 
			
		||||
    byte[] empty = new byte[16];
 | 
			
		||||
    byte[] combinedUnknownAccessKeys = uuidToAccountMap.values().stream()
 | 
			
		||||
        .map(Account::getUnidentifiedAccessKey)
 | 
			
		||||
        .map(accessKey -> {
 | 
			
		||||
          if (accessKey.isEmpty()) {
 | 
			
		||||
            throwUnauthorized.set(true);
 | 
			
		||||
            return empty;
 | 
			
		||||
          }
 | 
			
		||||
          return accessKey.get();
 | 
			
		||||
        })
 | 
			
		||||
        .reduce(new byte[16], (bytes, bytes2) -> {
 | 
			
		||||
          if (bytes.length != bytes2.length) {
 | 
			
		||||
            throwUnauthorized.set(true);
 | 
			
		||||
            return bytes;
 | 
			
		||||
          }
 | 
			
		||||
          for (int i = 0; i < bytes.length; i++) {
 | 
			
		||||
            bytes[i] ^= bytes2[i];
 | 
			
		||||
          }
 | 
			
		||||
          return bytes;
 | 
			
		||||
        });
 | 
			
		||||
    if (throwUnauthorized.get()
 | 
			
		||||
        || !MessageDigest.isEqual(combinedUnknownAccessKeys, accessKeys.getAccessKeys())) {
 | 
			
		||||
      throw new WebApplicationException(Status.UNAUTHORIZED);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private Response declineDelivery(final IncomingMessageList messages, final Account source, final Account destination) {
 | 
			
		||||
    Metrics.counter(DECLINED_DELIVERY_COUNTER, SENDER_COUNTRY_TAG_NAME, Util.getCountryCode(source.getNumber())).increment();
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -464,6 +572,34 @@ public class MessageController {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private void sendMessage(Account destinationAccount, Device destinationDevice, long timestamp, boolean online,
 | 
			
		||||
      Recipient recipient, byte[] commonPayload) throws NoSuchUserException {
 | 
			
		||||
    try (final Timer.Context ignored = sendCommonMessageInternalTimer.time()) {
 | 
			
		||||
      Envelope.Builder messageBuilder = Envelope.newBuilder();
 | 
			
		||||
      long serverTimestamp = System.currentTimeMillis();
 | 
			
		||||
      byte[] recipientKeyMaterial = recipient.getPerRecipientKeyMaterial();
 | 
			
		||||
 | 
			
		||||
      byte[] payload = new byte[1 + recipientKeyMaterial.length + commonPayload.length];
 | 
			
		||||
      payload[0] = MultiRecipientMessageProvider.VERSION;
 | 
			
		||||
      System.arraycopy(recipientKeyMaterial, 0, payload, 1, recipientKeyMaterial.length);
 | 
			
		||||
      System.arraycopy(commonPayload, 0, payload, 1 + recipientKeyMaterial.length, payload.length);
 | 
			
		||||
 | 
			
		||||
      messageBuilder
 | 
			
		||||
          .setType(Type.UNIDENTIFIED_SENDER)
 | 
			
		||||
          .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
 | 
			
		||||
          .setServerTimestamp(serverTimestamp)
 | 
			
		||||
          .setContent(ByteString.copyFrom(payload));
 | 
			
		||||
 | 
			
		||||
      messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
 | 
			
		||||
    } catch (NotPushRegisteredException e) {
 | 
			
		||||
      if (destinationDevice.isMaster()) {
 | 
			
		||||
        throw new NoSuchUserException(e);
 | 
			
		||||
      } else {
 | 
			
		||||
        logger.debug("Not registered", e);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
 | 
			
		||||
      throws StaleDevicesException
 | 
			
		||||
  {
 | 
			
		||||
| 
						 | 
				
			
			@ -485,22 +621,24 @@ public class MessageController {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage)
 | 
			
		||||
      throws MismatchedDevicesException {
 | 
			
		||||
    Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet());
 | 
			
		||||
    validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  private void validateCompleteDeviceList(Account account,
 | 
			
		||||
                                          List<IncomingMessage> messages,
 | 
			
		||||
                                          Set<Long> messageDeviceIds,
 | 
			
		||||
                                          boolean isSyncMessage)
 | 
			
		||||
      throws MismatchedDevicesException
 | 
			
		||||
  {
 | 
			
		||||
    Set<Long> messageDeviceIds = new HashSet<>();
 | 
			
		||||
    Set<Long> accountDeviceIds = new HashSet<>();
 | 
			
		||||
 | 
			
		||||
    List<Long> missingDeviceIds = new LinkedList<>();
 | 
			
		||||
    List<Long> extraDeviceIds   = new LinkedList<>();
 | 
			
		||||
 | 
			
		||||
    for (IncomingMessage message : messages) {
 | 
			
		||||
      messageDeviceIds.add(message.getDestinationDeviceId());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (Device device : account.getDevices()) {
 | 
			
		||||
   for (Device device : account.getDevices()) {
 | 
			
		||||
      if (device.isEnabled() &&
 | 
			
		||||
          !(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId()))
 | 
			
		||||
      {
 | 
			
		||||
| 
						 | 
				
			
			@ -512,9 +650,9 @@ public class MessageController {
 | 
			
		|||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (IncomingMessage message : messages) {
 | 
			
		||||
      if (!accountDeviceIds.contains(message.getDestinationDeviceId())) {
 | 
			
		||||
        extraDeviceIds.add(message.getDestinationDeviceId());
 | 
			
		||||
    for (Long deviceId : messageDeviceIds) {
 | 
			
		||||
      if (!accountDeviceIds.contains(deviceId)) {
 | 
			
		||||
        extraDeviceIds.add(deviceId);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,67 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2021 Signal Messenger, LLC
 | 
			
		||||
 * SPDX-License-Identifier: AGPL-3.0-only
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.whispersystems.textsecuregcm.entities;
 | 
			
		||||
 | 
			
		||||
import java.util.UUID;
 | 
			
		||||
import javax.validation.constraints.Min;
 | 
			
		||||
import javax.validation.constraints.NotNull;
 | 
			
		||||
import javax.validation.constraints.Size;
 | 
			
		||||
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
 | 
			
		||||
 | 
			
		||||
public class MultiRecipientMessage {
 | 
			
		||||
 | 
			
		||||
  public static class Recipient {
 | 
			
		||||
 | 
			
		||||
    @NotNull
 | 
			
		||||
    private final UUID uuid;
 | 
			
		||||
 | 
			
		||||
    @Min(1)
 | 
			
		||||
    private final long deviceId;
 | 
			
		||||
 | 
			
		||||
    @Size(min = 48, max = 48)
 | 
			
		||||
    @NotNull
 | 
			
		||||
    private final byte[] perRecipientKeyMaterial;
 | 
			
		||||
 | 
			
		||||
    public Recipient(UUID uuid, long deviceId, byte[] perRecipientKeyMaterial) {
 | 
			
		||||
      this.uuid = uuid;
 | 
			
		||||
      this.deviceId = deviceId;
 | 
			
		||||
      this.perRecipientKeyMaterial = perRecipientKeyMaterial;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public UUID getUuid() {
 | 
			
		||||
      return uuid;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public long getDeviceId() {
 | 
			
		||||
      return deviceId;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public byte[] getPerRecipientKeyMaterial() {
 | 
			
		||||
      return perRecipientKeyMaterial;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @NotNull
 | 
			
		||||
  @Size(min = 1, max = MultiRecipientMessageProvider.MAX_RECIPIENT_COUNT)
 | 
			
		||||
  private final Recipient[] recipients;
 | 
			
		||||
 | 
			
		||||
  @NotNull
 | 
			
		||||
  @Min(32)
 | 
			
		||||
  private final byte[] commonPayload;
 | 
			
		||||
 | 
			
		||||
  public MultiRecipientMessage(Recipient[] recipients, byte[] commonPayload) {
 | 
			
		||||
    this.recipients = recipients;
 | 
			
		||||
    this.commonPayload = commonPayload;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public Recipient[] getRecipients() {
 | 
			
		||||
    return recipients;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public byte[] getCommonPayload() {
 | 
			
		||||
    return commonPayload;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -6,16 +6,24 @@
 | 
			
		|||
package org.whispersystems.textsecuregcm.entities;
 | 
			
		||||
 | 
			
		||||
import com.fasterxml.jackson.annotation.JsonProperty;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.UUID;
 | 
			
		||||
 | 
			
		||||
public class SendMessageResponse {
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private boolean needsSync;
 | 
			
		||||
 | 
			
		||||
  @JsonProperty
 | 
			
		||||
  private List<UUID> uuids404;
 | 
			
		||||
 | 
			
		||||
  public SendMessageResponse() {}
 | 
			
		||||
 | 
			
		||||
  public SendMessageResponse(boolean needsSync) {
 | 
			
		||||
    this.needsSync = needsSync;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public SendMessageResponse(List<UUID> uuids404) {
 | 
			
		||||
    this.uuids404 = uuids404;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,129 @@
 | 
			
		|||
/*
 | 
			
		||||
 * Copyright 2021 Signal Messenger, LLC
 | 
			
		||||
 * SPDX-License-Identifier: AGPL-3.0-only
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package org.whispersystems.textsecuregcm.providers;
 | 
			
		||||
 | 
			
		||||
import io.dropwizard.util.DataSizeUnit;
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
import java.io.InputStream;
 | 
			
		||||
import java.lang.annotation.Annotation;
 | 
			
		||||
import java.lang.reflect.Type;
 | 
			
		||||
import java.util.UUID;
 | 
			
		||||
import javax.ws.rs.BadRequestException;
 | 
			
		||||
import javax.ws.rs.Consumes;
 | 
			
		||||
import javax.ws.rs.WebApplicationException;
 | 
			
		||||
import javax.ws.rs.core.MediaType;
 | 
			
		||||
import javax.ws.rs.core.MultivaluedMap;
 | 
			
		||||
import javax.ws.rs.core.NoContentException;
 | 
			
		||||
import javax.ws.rs.ext.MessageBodyReader;
 | 
			
		||||
import javax.ws.rs.ext.Provider;
 | 
			
		||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
 | 
			
		||||
 | 
			
		||||
@Provider
 | 
			
		||||
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
 | 
			
		||||
public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRecipientMessage> {
 | 
			
		||||
 | 
			
		||||
  public static final String MEDIA_TYPE = "application/vnd.signal-messenger.mrm";
 | 
			
		||||
  public static final int MAX_RECIPIENT_COUNT = 5000;
 | 
			
		||||
  public static final int MAX_MESSAGE_SIZE = Math.toIntExact(32 + DataSizeUnit.KIBIBYTES.toBytes(256));
 | 
			
		||||
  public static final byte VERSION = 0x22;
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public boolean isReadable(Class<?> type, Type genericType, Annotation[] annotations, MediaType mediaType) {
 | 
			
		||||
    return MEDIA_TYPE.equals(mediaType.toString()) && MultiRecipientMessage.class.isAssignableFrom(type);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Override
 | 
			
		||||
  public MultiRecipientMessage readFrom(Class<MultiRecipientMessage> type, Type genericType, Annotation[] annotations,
 | 
			
		||||
      MediaType mediaType, MultivaluedMap<String, String> httpHeaders, InputStream entityStream)
 | 
			
		||||
      throws IOException, WebApplicationException {
 | 
			
		||||
    int versionByte = entityStream.read();
 | 
			
		||||
    if (versionByte == -1) {
 | 
			
		||||
      throw new NoContentException("Empty body not allowed");
 | 
			
		||||
    }
 | 
			
		||||
    if (versionByte != VERSION) {
 | 
			
		||||
      throw new BadRequestException("Unsupported version");
 | 
			
		||||
    }
 | 
			
		||||
    long count = readVarint(entityStream);
 | 
			
		||||
    if (count > MAX_RECIPIENT_COUNT) {
 | 
			
		||||
      throw new BadRequestException("Maximum recipient count exceeded");
 | 
			
		||||
    }
 | 
			
		||||
    MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)];
 | 
			
		||||
    for (int i = 0; i < Math.toIntExact(count); i++) {
 | 
			
		||||
      UUID uuid = readUuid(entityStream);
 | 
			
		||||
      long deviceId = readVarint(entityStream);
 | 
			
		||||
      byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
 | 
			
		||||
      if (perRecipientKeyMaterial.length != 48) {
 | 
			
		||||
        throw new IOException("Failed to read expected number of key material bytes for a recipient");
 | 
			
		||||
      }
 | 
			
		||||
      recipients[i] = new MultiRecipientMessage.Recipient(uuid, deviceId, perRecipientKeyMaterial);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // caller is responsible for checking that the entity stream is at EOF when we return; if there are more bytes than
 | 
			
		||||
    // this it'll return an error back. We just need to limit how many we'll accept here.
 | 
			
		||||
    byte[] commonPayload = entityStream.readNBytes(MAX_MESSAGE_SIZE);
 | 
			
		||||
    if (commonPayload.length < 32) {
 | 
			
		||||
      throw new IOException("Failed to read expected number of common key material bytes");
 | 
			
		||||
    }
 | 
			
		||||
    return new MultiRecipientMessage(recipients, commonPayload);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Reads a UUID in network byte order and converts to a UUID object.
 | 
			
		||||
   */
 | 
			
		||||
  private UUID readUuid(InputStream stream) throws IOException {
 | 
			
		||||
    byte[] buffer = new byte[8];
 | 
			
		||||
 | 
			
		||||
    int read = stream.read(buffer);
 | 
			
		||||
    if (read != 8) {
 | 
			
		||||
      throw new IOException("Insufficient bytes for UUID");
 | 
			
		||||
    }
 | 
			
		||||
    long msb = convertNetworkByteOrderToLong(buffer);
 | 
			
		||||
 | 
			
		||||
    read = stream.read(buffer);
 | 
			
		||||
    if (read != 8) {
 | 
			
		||||
      throw new IOException("Insufficient bytes for UUID");
 | 
			
		||||
    }
 | 
			
		||||
    long lsb = convertNetworkByteOrderToLong(buffer);
 | 
			
		||||
 | 
			
		||||
    return new UUID(msb, lsb);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  private long convertNetworkByteOrderToLong(byte[] buffer) {
 | 
			
		||||
    long result = 0;
 | 
			
		||||
    for (int i = 0; i < 8; i++) {
 | 
			
		||||
      result = (result << (i * 8)) | (buffer[i] & 0xFFL);
 | 
			
		||||
    }
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Reads a varint. A varint larger than 64 bits is rejected with a {@code WebApplicationException}. An
 | 
			
		||||
   * {@code IOException} is thrown if the stream ends before we finish reading the varint.
 | 
			
		||||
   *
 | 
			
		||||
   * @return the varint value
 | 
			
		||||
   */
 | 
			
		||||
  private long readVarint(InputStream stream) throws IOException, WebApplicationException {
 | 
			
		||||
    boolean hasMore = true;
 | 
			
		||||
    int currentOffset = 0;
 | 
			
		||||
    int result = 0;
 | 
			
		||||
    while (hasMore) {
 | 
			
		||||
      if (currentOffset >= 64) {
 | 
			
		||||
        throw new BadRequestException("varint is too large");
 | 
			
		||||
      }
 | 
			
		||||
      int b = stream.read();
 | 
			
		||||
      if (b == -1) {
 | 
			
		||||
        throw new IOException("Missing byte " + (currentOffset / 7) + " of varint");
 | 
			
		||||
      }
 | 
			
		||||
      if (currentOffset == 63 && (b & 0xFE) != 0) {
 | 
			
		||||
        throw new BadRequestException("varint is too large");
 | 
			
		||||
      }
 | 
			
		||||
      hasMore = (b & 0x80) != 0;
 | 
			
		||||
      result |= (b & 0x7F) << currentOffset;
 | 
			
		||||
      currentOffset += 7;
 | 
			
		||||
    }
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue