Add routing for stories.
This commit is contained in:
		
							parent
							
								
									c2ab72c77e
								
							
						
					
					
						commit
						966c3a8f47
					
				|  | @ -68,6 +68,9 @@ public class RateLimitsConfiguration { | ||||||
|   @JsonProperty |   @JsonProperty | ||||||
|   private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0); |   private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0); | ||||||
| 
 | 
 | ||||||
|  |   @JsonProperty | ||||||
|  |   private RateLimitConfiguration stories = new RateLimitConfiguration(10_000, 10_000 / (24.0 * 60.0)); | ||||||
|  | 
 | ||||||
|   public RateLimitConfiguration getAutoBlock() { |   public RateLimitConfiguration getAutoBlock() { | ||||||
|     return autoBlock; |     return autoBlock; | ||||||
|   } |   } | ||||||
|  | @ -148,6 +151,8 @@ public class RateLimitsConfiguration { | ||||||
|     return checkAccountExistence; |     return checkAccountExistence; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   public RateLimitConfiguration getStories() { return stories; } | ||||||
|  | 
 | ||||||
|   public static class RateLimitConfiguration { |   public static class RateLimitConfiguration { | ||||||
|     @JsonProperty |     @JsonProperty | ||||||
|     private int bucketSize; |     private int bucketSize; | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import java.util.Base64; | ||||||
| import java.util.Collection; | import java.util.Collection; | ||||||
| import java.util.Collections; | import java.util.Collections; | ||||||
| import java.util.HashSet; | import java.util.HashSet; | ||||||
|  | import java.util.LinkedList; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import java.util.Map; | import java.util.Map; | ||||||
| import java.util.Objects; | import java.util.Objects; | ||||||
|  | @ -33,7 +34,9 @@ import java.util.concurrent.ExecutorService; | ||||||
| import java.util.concurrent.atomic.AtomicBoolean; | import java.util.concurrent.atomic.AtomicBoolean; | ||||||
| import java.util.function.Function; | import java.util.function.Function; | ||||||
| import java.util.stream.Collectors; | import java.util.stream.Collectors; | ||||||
|  | import java.util.stream.Stream; | ||||||
| import javax.annotation.Nonnull; | import javax.annotation.Nonnull; | ||||||
|  | import javax.annotation.Nullable; | ||||||
| import javax.validation.Valid; | import javax.validation.Valid; | ||||||
| import javax.validation.constraints.NotNull; | import javax.validation.constraints.NotNull; | ||||||
| import javax.ws.rs.BadRequestException; | import javax.ws.rs.BadRequestException; | ||||||
|  | @ -92,6 +95,7 @@ import org.whispersystems.textsecuregcm.util.Util; | ||||||
| import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; | import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; | ||||||
| import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; | import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; | ||||||
| import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; | import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; | ||||||
|  | import org.whispersystems.websocket.Stories; | ||||||
| 
 | 
 | ||||||
| @SuppressWarnings("OptionalUsedAsFieldOrParameterType") | @SuppressWarnings("OptionalUsedAsFieldOrParameterType") | ||||||
| @Path("/v1/messages") | @Path("/v1/messages") | ||||||
|  | @ -164,7 +168,9 @@ public class MessageController { | ||||||
|       @NotNull @Valid IncomingMessageList messages) |       @NotNull @Valid IncomingMessageList messages) | ||||||
|       throws RateLimitExceededException { |       throws RateLimitExceededException { | ||||||
| 
 | 
 | ||||||
|     if (source.isEmpty() && accessKey.isEmpty()) { |     boolean isStory = messages.story(); | ||||||
|  | 
 | ||||||
|  |     if (source.isEmpty() && accessKey.isEmpty() && !isStory) { | ||||||
|       throw new WebApplicationException(Response.Status.UNAUTHORIZED); |       throw new WebApplicationException(Response.Status.UNAUTHORIZED); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -204,11 +210,18 @@ public class MessageController { | ||||||
|         destination = source.map(AuthenticatedAccount::getAccount); |         destination = source.map(AuthenticatedAccount::getAccount); | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); |       // Stories will be checked by the client; we bypass access checks here for stories. | ||||||
|  |       if (!isStory) { | ||||||
|  |         OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); | ||||||
|  |       } | ||||||
|       assert (destination.isPresent()); |       assert (destination.isPresent()); | ||||||
| 
 | 
 | ||||||
|       if (source.isPresent() && !isSyncMessage) { |       if (source.isPresent() && !isSyncMessage) { | ||||||
|         checkRateLimit(source.get(), destination.get(), userAgent); |         checkMessageRateLimit(source.get(), destination.get(), userAgent); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       if (isStory) { | ||||||
|  |         checkStoryRateLimit(destination.get()); | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       final Set<Long> excludedDeviceIds; |       final Set<Long> excludedDeviceIds; | ||||||
|  | @ -238,12 +251,12 @@ public class MessageController { | ||||||
| 
 | 
 | ||||||
|         if (destinationDevice.isPresent()) { |         if (destinationDevice.isPresent()) { | ||||||
|           Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); |           Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); | ||||||
|           sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), messages.urgent(), incomingMessage, userAgent); |           sendIndividualMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), isStory, messages.urgent(), incomingMessage, userAgent); | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|       return Response.ok(new SendMessageResponse( |       boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1; | ||||||
|           !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build(); |       return Response.ok(new SendMessageResponse(needsSync)).build(); | ||||||
|     } catch (NoSuchUserException e) { |     } catch (NoSuchUserException e) { | ||||||
|       throw new WebApplicationException(Response.status(404).build()); |       throw new WebApplicationException(Response.status(404).build()); | ||||||
|     } catch (MismatchedDevicesException e) { |     } catch (MismatchedDevicesException e) { | ||||||
|  | @ -260,6 +273,35 @@ public class MessageController { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Build mapping of accounts to devices/registration IDs. | ||||||
|  |    * <p> | ||||||
|  |    * Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive | ||||||
|  |    * stories. | ||||||
|  |    * | ||||||
|  |    * @param multiRecipientMessage | ||||||
|  |    * @param uuidToAccountMap | ||||||
|  |    * @return | ||||||
|  |    */ | ||||||
|  |   private Map<Account, Set<Pair<Long, Integer>>> buildDeviceIdAndRegistrationIdMap( | ||||||
|  |       MultiRecipientMessage multiRecipientMessage, | ||||||
|  |       Map<UUID, Account> uuidToAccountMap | ||||||
|  |   ) { | ||||||
|  | 
 | ||||||
|  |     Stream<Recipient> recipients = Arrays.stream(multiRecipientMessage.getRecipients()); | ||||||
|  | 
 | ||||||
|  |     return recipients.collect(Collectors.toMap( | ||||||
|  |         recipient -> uuidToAccountMap.get(recipient.getUuid()), | ||||||
|  |         recipient -> new HashSet<>( | ||||||
|  |             Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), | ||||||
|  |         (a, b) -> { | ||||||
|  |           a.addAll(b); | ||||||
|  |           return a; | ||||||
|  |         } | ||||||
|  |     )); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   @Timed |   @Timed | ||||||
|   @Path("/multi_recipient") |   @Path("/multi_recipient") | ||||||
|   @PUT |   @PUT | ||||||
|  | @ -267,43 +309,51 @@ public class MessageController { | ||||||
|   @Produces(MediaType.APPLICATION_JSON) |   @Produces(MediaType.APPLICATION_JSON) | ||||||
|   @FilterAbusiveMessages |   @FilterAbusiveMessages | ||||||
|   public Response sendMultiRecipientMessage( |   public Response sendMultiRecipientMessage( | ||||||
|       @HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys, |       @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, | ||||||
|       @HeaderParam("User-Agent") String userAgent, |       @HeaderParam("User-Agent") String userAgent, | ||||||
|       @HeaderParam("X-Forwarded-For") String forwardedFor, |       @HeaderParam("X-Forwarded-For") String forwardedFor, | ||||||
|       @QueryParam("online") boolean online, |       @QueryParam("online") boolean online, | ||||||
|       @QueryParam("ts") long timestamp, |       @QueryParam("ts") long timestamp, | ||||||
|  |       @QueryParam("story") boolean isStory, | ||||||
|       @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { |       @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { | ||||||
| 
 | 
 | ||||||
|     Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) |     Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) | ||||||
|         .map(Recipient::getUuid) |         .map(Recipient::getUuid) | ||||||
|         .distinct() |         .distinct() | ||||||
|         .collect(Collectors.toUnmodifiableMap(Function.identity(), uuid -> { |         .collect(Collectors.toUnmodifiableMap( | ||||||
|           Optional<Account> account = accountsManager.getByAccountIdentifier(uuid); |             Function.identity(), | ||||||
|           if (account.isEmpty()) { |             uuid -> accountsManager | ||||||
|             throw new WebApplicationException(Status.NOT_FOUND); |                 .getByAccountIdentifier(uuid) | ||||||
|           } |                 .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)))); | ||||||
|           return account.get(); |  | ||||||
|         })); |  | ||||||
|     checkAccessKeys(accessKeys, uuidToAccountMap); |  | ||||||
| 
 | 
 | ||||||
|     final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap = |     // Stories will be checked by the client; we bypass access checks here for stories. | ||||||
|         Arrays |     if (!isStory) { | ||||||
|             .stream(multiRecipientMessage.getRecipients()) |       checkAccessKeys(accessKeys, uuidToAccountMap); | ||||||
|             .collect(Collectors.toMap( |     } | ||||||
|                 recipient -> uuidToAccountMap.get(recipient.getUuid()), | 
 | ||||||
|                 recipient -> new HashSet<>( |     final Map<Account, Set<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap = | ||||||
|                     Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), |         buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, uuidToAccountMap); | ||||||
|                 (a, b) -> { | 
 | ||||||
|                   a.addAll(b); |     // We might filter out all the recipients of a story (if none have enabled stories). | ||||||
|                   return a; |     // In this case there is no error so we should just return 200 now. | ||||||
|                 } |     if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) { | ||||||
|             )); |       return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build(); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); |     Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); | ||||||
|     Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); |     Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); | ||||||
|     uuidToAccountMap.values().forEach(account -> { |     uuidToAccountMap.values().forEach(account -> { | ||||||
|       final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first) | 
 | ||||||
|           .collect(Collectors.toSet()); |       if (isStory) { | ||||||
|  |         checkStoryRateLimit(account); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap | ||||||
|  |         .getOrDefault(account, Collections.emptySet()) | ||||||
|  |         .stream() | ||||||
|  |         .map(Pair::first) | ||||||
|  |         .collect(Collectors.toSet()); | ||||||
|  | 
 | ||||||
|       try { |       try { | ||||||
|         DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); |         DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); | ||||||
| 
 | 
 | ||||||
|  | @ -351,8 +401,8 @@ public class MessageController { | ||||||
|             Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); |             Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); | ||||||
|             sentMessageCounter.increment(); |             sentMessageCounter.increment(); | ||||||
|             try { |             try { | ||||||
|               sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, |               sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, | ||||||
|                   multiRecipientMessage.getCommonPayload()); |                   recipient, multiRecipientMessage.getCommonPayload()); | ||||||
|             } catch (NoSuchUserException e) { |             } catch (NoSuchUserException e) { | ||||||
|               uuids404.add(destinationAccount.getUuid()); |               uuids404.add(destinationAccount.getUuid()); | ||||||
|             } |             } | ||||||
|  | @ -367,6 +417,10 @@ public class MessageController { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) { |   private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) { | ||||||
|  |     // We should not have null access keys when checking access; bail out early. | ||||||
|  |     if (accessKeys == null) { | ||||||
|  |       throw new WebApplicationException(Status.UNAUTHORIZED); | ||||||
|  |     } | ||||||
|     AtomicBoolean throwUnauthorized = new AtomicBoolean(false); |     AtomicBoolean throwUnauthorized = new AtomicBoolean(false); | ||||||
|     byte[] empty = new byte[16]; |     byte[] empty = new byte[16]; | ||||||
|     final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); |     final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); | ||||||
|  | @ -405,8 +459,11 @@ public class MessageController { | ||||||
|   @GET |   @GET | ||||||
|   @Produces(MediaType.APPLICATION_JSON) |   @Produces(MediaType.APPLICATION_JSON) | ||||||
|   public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth, |   public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth, | ||||||
|  |       @HeaderParam(Stories.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader, | ||||||
|       @HeaderParam("User-Agent") String userAgent) { |       @HeaderParam("User-Agent") String userAgent) { | ||||||
| 
 | 
 | ||||||
|  |     boolean shouldReceiveStories = Stories.parseReceiveStoriesHeader(receiveStoriesHeader); | ||||||
|  | 
 | ||||||
|     pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent); |     pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent); | ||||||
| 
 | 
 | ||||||
|     final OutgoingMessageEntityList outgoingMessages; |     final OutgoingMessageEntityList outgoingMessages; | ||||||
|  | @ -416,7 +473,12 @@ public class MessageController { | ||||||
|           auth.getAuthenticatedDevice().getId(), |           auth.getAuthenticatedDevice().getId(), | ||||||
|           false); |           false); | ||||||
| 
 | 
 | ||||||
|       outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream() |       Stream<Envelope> envelopes = messagesAndHasMore.first().stream(); | ||||||
|  |       if (!shouldReceiveStories) { | ||||||
|  |         envelopes = envelopes.filter(e -> !e.getStory()); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       outgoingMessages = new OutgoingMessageEntityList(envelopes | ||||||
|           .map(OutgoingMessageEntity::fromEnvelope) |           .map(OutgoingMessageEntity::fromEnvelope) | ||||||
|           .peek(outgoingMessageEntity -> MessageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), |           .peek(outgoingMessageEntity -> MessageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), | ||||||
|               outgoingMessageEntity)) |               outgoingMessageEntity)) | ||||||
|  | @ -514,12 +576,13 @@ public class MessageController { | ||||||
|         .build(); |         .build(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void sendMessage(Optional<AuthenticatedAccount> source, |   private void sendIndividualMessage(Optional<AuthenticatedAccount> source, | ||||||
|       Account destinationAccount, |       Account destinationAccount, | ||||||
|       Device destinationDevice, |       Device destinationDevice, | ||||||
|       UUID destinationUuid, |       UUID destinationUuid, | ||||||
|       long timestamp, |       long timestamp, | ||||||
|       boolean online, |       boolean online, | ||||||
|  |       boolean story, | ||||||
|       boolean urgent, |       boolean urgent, | ||||||
|       IncomingMessage incomingMessage, |       IncomingMessage incomingMessage, | ||||||
|       String userAgentString) |       String userAgentString) | ||||||
|  | @ -532,6 +595,7 @@ public class MessageController { | ||||||
|             source.map(AuthenticatedAccount::getAccount).orElse(null), |             source.map(AuthenticatedAccount::getAccount).orElse(null), | ||||||
|             source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), |             source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), | ||||||
|             timestamp == 0 ? System.currentTimeMillis() : timestamp, |             timestamp == 0 ? System.currentTimeMillis() : timestamp, | ||||||
|  |             story, | ||||||
|             urgent); |             urgent); | ||||||
|       } catch (final IllegalArgumentException e) { |       } catch (final IllegalArgumentException e) { | ||||||
|         logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); |         logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); | ||||||
|  | @ -545,10 +609,11 @@ public class MessageController { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void sendMessage(Account destinationAccount, |   private void sendCommonPayloadMessage(Account destinationAccount, | ||||||
|       Device destinationDevice, |       Device destinationDevice, | ||||||
|       long timestamp, |       long timestamp, | ||||||
|       boolean online, |       boolean online, | ||||||
|  |       boolean story, | ||||||
|       Recipient recipient, |       Recipient recipient, | ||||||
|       byte[] commonPayload) throws NoSuchUserException { |       byte[] commonPayload) throws NoSuchUserException { | ||||||
|     try { |     try { | ||||||
|  | @ -566,6 +631,7 @@ public class MessageController { | ||||||
|           .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) |           .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) | ||||||
|           .setServerTimestamp(serverTimestamp) |           .setServerTimestamp(serverTimestamp) | ||||||
|           .setContent(ByteString.copyFrom(payload)) |           .setContent(ByteString.copyFrom(payload)) | ||||||
|  |           .setStory(story) | ||||||
|           .setDestinationUuid(destinationAccount.getUuid().toString()); |           .setDestinationUuid(destinationAccount.getUuid().toString()); | ||||||
| 
 | 
 | ||||||
|       messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); |       messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); | ||||||
|  | @ -578,7 +644,14 @@ public class MessageController { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private void checkRateLimit(AuthenticatedAccount source, Account destination, String userAgent) |   private void checkStoryRateLimit(Account destination) { | ||||||
|  |     try { | ||||||
|  |       rateLimiters.getMessagesLimiter().validate(destination.getUuid()); | ||||||
|  |     } catch (final RateLimitExceededException e) { | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private void checkMessageRateLimit(AuthenticatedAccount source, Account destination, String userAgent) | ||||||
|       throws RateLimitExceededException { |       throws RateLimitExceededException { | ||||||
|     final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); |     final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -15,6 +15,10 @@ public class AccountMismatchedDevices { | ||||||
|   @JsonProperty |   @JsonProperty | ||||||
|   public final MismatchedDevices devices; |   public final MismatchedDevices devices; | ||||||
| 
 | 
 | ||||||
|  |   public String toString() { | ||||||
|  |     return "AccountMismatchedDevices(" + uuid + ", " + devices + ")"; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) { |   public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) { | ||||||
|     this.uuid = uuid; |     this.uuid = uuid; | ||||||
|     this.devices = devices; |     this.devices = devices; | ||||||
|  |  | ||||||
|  | @ -15,6 +15,10 @@ public class AccountStaleDevices { | ||||||
|   @JsonProperty |   @JsonProperty | ||||||
|   public final StaleDevices devices; |   public final StaleDevices devices; | ||||||
| 
 | 
 | ||||||
|  |   public String toString() { | ||||||
|  |     return "AccountStaleDevices(" + uuid + ", " + devices + ")"; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { |   public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { | ||||||
|     this.uuid = uuid; |     this.uuid = uuid; | ||||||
|     this.devices = devices; |     this.devices = devices; | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio | ||||||
|       @Nullable Account sourceAccount, |       @Nullable Account sourceAccount, | ||||||
|       @Nullable Long sourceDeviceId, |       @Nullable Long sourceDeviceId, | ||||||
|       final long timestamp, |       final long timestamp, | ||||||
|  |       final boolean story, | ||||||
|       final boolean urgent) { |       final boolean urgent) { | ||||||
| 
 | 
 | ||||||
|     final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); |     final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); | ||||||
|  | @ -31,6 +32,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio | ||||||
|         .setTimestamp(timestamp) |         .setTimestamp(timestamp) | ||||||
|         .setServerTimestamp(System.currentTimeMillis()) |         .setServerTimestamp(System.currentTimeMillis()) | ||||||
|         .setDestinationUuid(destinationUuid.toString()) |         .setDestinationUuid(destinationUuid.toString()) | ||||||
|  |         .setStory(story) | ||||||
|         .setUrgent(urgent); |         .setUrgent(urgent); | ||||||
| 
 | 
 | ||||||
|     if (sourceAccount != null && sourceDeviceId != null) { |     if (sourceAccount != null && sourceDeviceId != null) { | ||||||
|  |  | ||||||
|  | @ -11,14 +11,15 @@ import javax.validation.Valid; | ||||||
| import javax.validation.constraints.NotNull; | import javax.validation.constraints.NotNull; | ||||||
| 
 | 
 | ||||||
| public record IncomingMessageList(@NotNull @Valid List<@NotNull IncomingMessage> messages, | public record IncomingMessageList(@NotNull @Valid List<@NotNull IncomingMessage> messages, | ||||||
|                                   boolean online, boolean urgent, long timestamp) { |                                   boolean online, boolean urgent, boolean story, long timestamp) { | ||||||
| 
 | 
 | ||||||
|   @JsonCreator |   @JsonCreator | ||||||
|   public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNull IncomingMessage> messages, |   public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNull IncomingMessage> messages, | ||||||
|       @JsonProperty("online") boolean online, |       @JsonProperty("online") boolean online, | ||||||
|       @JsonProperty("urgent") Boolean urgent, |       @JsonProperty("urgent") Boolean urgent, | ||||||
|  |       @JsonProperty("story") Boolean story, | ||||||
|       @JsonProperty("timestamp") long timestamp) { |       @JsonProperty("timestamp") long timestamp) { | ||||||
| 
 | 
 | ||||||
|     this(messages, online, urgent == null || urgent, timestamp); |     this(messages, online, urgent == null || urgent, story != null && story, timestamp); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -21,6 +21,10 @@ public class MismatchedDevices { | ||||||
|   @VisibleForTesting |   @VisibleForTesting | ||||||
|   public MismatchedDevices() {} |   public MismatchedDevices() {} | ||||||
| 
 | 
 | ||||||
|  |   public String toString() { | ||||||
|  |     return "MismatchedDevices(" + missingDevices + ", " + extraDevices + ")"; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public MismatchedDevices(List<Long> missingDevices, List<Long> extraDevices) { |   public MismatchedDevices(List<Long> missingDevices, List<Long> extraDevices) { | ||||||
|     this.missingDevices = missingDevices; |     this.missingDevices = missingDevices; | ||||||
|     this.extraDevices   = extraDevices; |     this.extraDevices   = extraDevices; | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ import javax.annotation.Nullable; | ||||||
| 
 | 
 | ||||||
| public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, | public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, | ||||||
|                                     UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, |                                     UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, | ||||||
|                                     long serverTimestamp, boolean urgent) { |                                     long serverTimestamp, boolean urgent, boolean story) { | ||||||
| 
 | 
 | ||||||
|   public MessageProtos.Envelope toEnvelope() { |   public MessageProtos.Envelope toEnvelope() { | ||||||
|     final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() |     final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() | ||||||
|  | @ -22,6 +22,7 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab | ||||||
|         .setServerTimestamp(serverTimestamp()) |         .setServerTimestamp(serverTimestamp()) | ||||||
|         .setDestinationUuid(destinationUuid().toString()) |         .setDestinationUuid(destinationUuid().toString()) | ||||||
|         .setServerGuid(guid().toString()) |         .setServerGuid(guid().toString()) | ||||||
|  |         .setStory(story) | ||||||
|         .setUrgent(urgent); |         .setUrgent(urgent); | ||||||
| 
 | 
 | ||||||
|     if (sourceUuid() != null) { |     if (sourceUuid() != null) { | ||||||
|  | @ -51,7 +52,8 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab | ||||||
|         envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, |         envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, | ||||||
|         envelope.getContent().toByteArray(), |         envelope.getContent().toByteArray(), | ||||||
|         envelope.getServerTimestamp(), |         envelope.getServerTimestamp(), | ||||||
|         envelope.getUrgent()); |         envelope.getUrgent(), | ||||||
|  |         envelope.getStory()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Override |   @Override | ||||||
|  | @ -63,16 +65,23 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab | ||||||
|       return false; |       return false; | ||||||
|     } |     } | ||||||
|     final OutgoingMessageEntity that = (OutgoingMessageEntity) o; |     final OutgoingMessageEntity that = (OutgoingMessageEntity) o; | ||||||
|     return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice |     return guid.equals(that.guid) && | ||||||
|         && serverTimestamp == that.serverTimestamp && guid.equals(that.guid) |         type == that.type && | ||||||
|         && Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid) |         timestamp == that.timestamp && | ||||||
|         && Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content) && urgent == that.urgent; |         Objects.equals(sourceUuid, that.sourceUuid) && | ||||||
|  |         sourceDevice == that.sourceDevice && | ||||||
|  |         destinationUuid.equals(that.destinationUuid) && | ||||||
|  |         Objects.equals(updatedPni, that.updatedPni) && | ||||||
|  |         Arrays.equals(content, that.content) && | ||||||
|  |         serverTimestamp == that.serverTimestamp && | ||||||
|  |         urgent == that.urgent && | ||||||
|  |         story == that.story; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Override |   @Override | ||||||
|   public int hashCode() { |   public int hashCode() { | ||||||
|     int result = Objects.hash(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, |     int result = Objects.hash(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, | ||||||
|         serverTimestamp, urgent); |         serverTimestamp, urgent, story); | ||||||
|     result = 31 * result + Arrays.hashCode(content); |     result = 31 * result + Arrays.hashCode(content); | ||||||
|     return result; |     return result; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ | ||||||
| package org.whispersystems.textsecuregcm.entities; | package org.whispersystems.textsecuregcm.entities; | ||||||
| 
 | 
 | ||||||
| import com.fasterxml.jackson.annotation.JsonProperty; | import com.fasterxml.jackson.annotation.JsonProperty; | ||||||
|  | import com.google.common.annotations.VisibleForTesting; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import java.util.UUID; | import java.util.UUID; | ||||||
| 
 | 
 | ||||||
|  | @ -16,6 +17,15 @@ public class SendMultiRecipientMessageResponse { | ||||||
|   public SendMultiRecipientMessageResponse() { |   public SendMultiRecipientMessageResponse() { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   public String toString() { | ||||||
|  |     return "SendMultiRecipientMessageResponse(" + uuids404 + ")"; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @VisibleForTesting | ||||||
|  |   public List<UUID> getUUIDs404() { | ||||||
|  |     return this.uuids404; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public SendMultiRecipientMessageResponse(final List<UUID> uuids404) { |   public SendMultiRecipientMessageResponse(final List<UUID> uuids404) { | ||||||
|     this.uuids404 = uuids404; |     this.uuids404 = uuids404; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -16,6 +16,10 @@ public class StaleDevices { | ||||||
| 
 | 
 | ||||||
|   public StaleDevices() {} |   public StaleDevices() {} | ||||||
| 
 | 
 | ||||||
|  |   public String toString() { | ||||||
|  |     return "StaleDevices(" + staleDevices + ")"; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public StaleDevices(List<Long> staleDevices) { |   public StaleDevices(List<Long> staleDevices) { | ||||||
|     this.staleDevices = staleDevices; |     this.staleDevices = staleDevices; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -37,6 +37,8 @@ public class RateLimiters { | ||||||
| 
 | 
 | ||||||
|   private final RateLimiter checkAccountExistenceLimiter; |   private final RateLimiter checkAccountExistenceLimiter; | ||||||
| 
 | 
 | ||||||
|  |   private final RateLimiter storiesLimiter; | ||||||
|  | 
 | ||||||
|   public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) { |   public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) { | ||||||
|     this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", |     this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", | ||||||
|                                                  config.getSmsDestination().getBucketSize(), |                                                  config.getSmsDestination().getBucketSize(), | ||||||
|  | @ -118,6 +120,10 @@ public class RateLimiters { | ||||||
|     this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence", |     this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence", | ||||||
|         config.getCheckAccountExistence().getBucketSize(), |         config.getCheckAccountExistence().getBucketSize(), | ||||||
|         config.getCheckAccountExistence().getLeakRatePerMinute()); |         config.getCheckAccountExistence().getLeakRatePerMinute()); | ||||||
|  | 
 | ||||||
|  |     this.storiesLimiter = new RateLimiter(cacheCluster, "stories", | ||||||
|  |         config.getStories().getBucketSize(), | ||||||
|  |         config.getStories().getLeakRatePerMinute()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public RateLimiter getAllocateDeviceLimiter() { |   public RateLimiter getAllocateDeviceLimiter() { | ||||||
|  | @ -199,4 +205,6 @@ public class RateLimiters { | ||||||
|   public RateLimiter getCheckAccountExistenceLimiter() { |   public RateLimiter getCheckAccountExistenceLimiter() { | ||||||
|     return checkAccountExistenceLimiter; |     return checkAccountExistenceLimiter; | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   public RateLimiter getStoriesLimiter() { return storiesLimiter; } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ import java.util.concurrent.ScheduledExecutorService; | ||||||
| import java.util.concurrent.ScheduledFuture; | import java.util.concurrent.ScheduledFuture; | ||||||
| import java.util.concurrent.TimeUnit; | import java.util.concurrent.TimeUnit; | ||||||
| import java.util.concurrent.atomic.AtomicReference; | import java.util.concurrent.atomic.AtomicReference; | ||||||
|  | import org.eclipse.jetty.websocket.api.UpgradeResponse; | ||||||
| import org.slf4j.Logger; | import org.slf4j.Logger; | ||||||
| import org.slf4j.LoggerFactory; | import org.slf4j.LoggerFactory; | ||||||
| import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; | import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; | ||||||
|  |  | ||||||
|  | @ -332,7 +332,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac | ||||||
|         final Envelope envelope = messages.get(i); |         final Envelope envelope = messages.get(i); | ||||||
|         final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); |         final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); | ||||||
| 
 | 
 | ||||||
|         if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { |         final boolean discard; | ||||||
|  |         if (isDesktopClient && envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE) { | ||||||
|  |           discard = true; | ||||||
|  |         } else if (envelope.getStory() && !client.shouldDeliverStories()) { | ||||||
|  |           discard = true; | ||||||
|  |         } else { | ||||||
|  |           discard = false; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         if (discard) { | ||||||
|           messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); |           messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); | ||||||
|           discardedMessagesMeter.mark(); |           discardedMessagesMeter.mark(); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -42,7 +42,8 @@ message Envelope { | ||||||
|   optional string destination_uuid = 13; |   optional string destination_uuid = 13; | ||||||
|   optional bool urgent = 14 [default=true]; |   optional bool urgent = 14 [default=true]; | ||||||
|   optional string updated_pni = 15; |   optional string updated_pni = 15; | ||||||
|   // next: 16 |   optional bool story = 16; // indicates that the content is a story. | ||||||
|  |   // next: 17 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| message ProvisioningUuid { | message ProvisioningUuid { | ||||||
|  |  | ||||||
|  | @ -32,15 +32,21 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; | ||||||
| import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; | import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; | ||||||
| import io.dropwizard.testing.junit5.ResourceExtension; | import io.dropwizard.testing.junit5.ResourceExtension; | ||||||
| import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; | import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; | ||||||
|  | import java.io.ByteArrayInputStream; | ||||||
|  | import java.io.InputStream; | ||||||
|  | import java.nio.ByteBuffer; | ||||||
|  | import java.nio.ByteOrder; | ||||||
| import java.util.Arrays; | import java.util.Arrays; | ||||||
| import java.util.Base64; | import java.util.Base64; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| import java.util.Optional; | import java.util.Optional; | ||||||
| import java.util.UUID; | import java.util.UUID; | ||||||
|  | import java.util.concurrent.Callable; | ||||||
| import java.util.concurrent.ExecutorService; | import java.util.concurrent.ExecutorService; | ||||||
| import java.util.concurrent.TimeUnit; | import java.util.concurrent.TimeUnit; | ||||||
| import java.util.stream.Stream; | import java.util.stream.Stream; | ||||||
| import javax.ws.rs.client.Entity; | import javax.ws.rs.client.Entity; | ||||||
|  | import javax.ws.rs.client.Invocation; | ||||||
| import javax.ws.rs.core.MediaType; | import javax.ws.rs.core.MediaType; | ||||||
| import javax.ws.rs.core.Response; | import javax.ws.rs.core.Response; | ||||||
| import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; | import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; | ||||||
|  | @ -62,11 +68,13 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; | ||||||
| import org.whispersystems.textsecuregcm.entities.MismatchedDevices; | import org.whispersystems.textsecuregcm.entities.MismatchedDevices; | ||||||
| import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; | import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; | ||||||
| import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; | import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; | ||||||
|  | import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; | ||||||
| import org.whispersystems.textsecuregcm.entities.SignedPreKey; | import org.whispersystems.textsecuregcm.entities.SignedPreKey; | ||||||
| import org.whispersystems.textsecuregcm.entities.StaleDevices; | import org.whispersystems.textsecuregcm.entities.StaleDevices; | ||||||
| import org.whispersystems.textsecuregcm.limits.RateLimiter; | import org.whispersystems.textsecuregcm.limits.RateLimiter; | ||||||
| import org.whispersystems.textsecuregcm.limits.RateLimiters; | import org.whispersystems.textsecuregcm.limits.RateLimiters; | ||||||
| import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; | import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; | ||||||
|  | import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; | ||||||
| import org.whispersystems.textsecuregcm.push.MessageSender; | import org.whispersystems.textsecuregcm.push.MessageSender; | ||||||
| import org.whispersystems.textsecuregcm.push.PushNotificationManager; | import org.whispersystems.textsecuregcm.push.PushNotificationManager; | ||||||
| import org.whispersystems.textsecuregcm.push.ReceiptSender; | import org.whispersystems.textsecuregcm.push.ReceiptSender; | ||||||
|  | @ -80,6 +88,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; | ||||||
| import org.whispersystems.textsecuregcm.tests.util.AuthHelper; | import org.whispersystems.textsecuregcm.tests.util.AuthHelper; | ||||||
| import org.whispersystems.textsecuregcm.util.Pair; | import org.whispersystems.textsecuregcm.util.Pair; | ||||||
| import org.whispersystems.textsecuregcm.util.SystemMapper; | import org.whispersystems.textsecuregcm.util.SystemMapper; | ||||||
|  | import org.whispersystems.websocket.Stories; | ||||||
| 
 | 
 | ||||||
| @ExtendWith(DropwizardExtensionsSupport.class) | @ExtendWith(DropwizardExtensionsSupport.class) | ||||||
| class MessageControllerTest { | class MessageControllerTest { | ||||||
|  | @ -87,10 +96,20 @@ class MessageControllerTest { | ||||||
|   private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; |   private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; | ||||||
|   private static final UUID   SINGLE_DEVICE_UUID      = UUID.randomUUID(); |   private static final UUID   SINGLE_DEVICE_UUID      = UUID.randomUUID(); | ||||||
|   private static final UUID   SINGLE_DEVICE_PNI       = UUID.randomUUID(); |   private static final UUID   SINGLE_DEVICE_PNI       = UUID.randomUUID(); | ||||||
|  |   private static final int SINGLE_DEVICE_ID1 = 1; | ||||||
|  |   private static final int SINGLE_DEVICE_REG_ID1 = 111; | ||||||
| 
 | 
 | ||||||
|   private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; |   private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; | ||||||
|   private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); |   private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); | ||||||
|   private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); |   private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); | ||||||
|  |   private static final int MULTI_DEVICE_ID1 = 1; | ||||||
|  |   private static final int MULTI_DEVICE_ID2 = 2; | ||||||
|  |   private static final int MULTI_DEVICE_ID3 = 3; | ||||||
|  |   private static final int MULTI_DEVICE_REG_ID1 = 222; | ||||||
|  |   private static final int MULTI_DEVICE_REG_ID2 = 333; | ||||||
|  |   private static final int MULTI_DEVICE_REG_ID3 = 444; | ||||||
|  | 
 | ||||||
|  |   private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); | ||||||
| 
 | 
 | ||||||
|   private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; |   private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; | ||||||
|   private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); |   private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); | ||||||
|  | @ -116,6 +135,7 @@ class MessageControllerTest { | ||||||
|       .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( |       .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( | ||||||
|           ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) |           ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) | ||||||
|       .addProvider(RateLimitExceededExceptionMapper.class) |       .addProvider(RateLimitExceededExceptionMapper.class) | ||||||
|  |       .addProvider(MultiRecipientMessageProvider.class) | ||||||
|       .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) |       .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) | ||||||
|       .addResource( |       .addResource( | ||||||
|           new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, |           new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, | ||||||
|  | @ -125,18 +145,18 @@ class MessageControllerTest { | ||||||
|   @BeforeEach |   @BeforeEach | ||||||
|   void setup() { |   void setup() { | ||||||
|     final List<Device> singleDeviceList = List.of( |     final List<Device> singleDeviceList = List.of( | ||||||
|         generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) |         generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|     final List<Device> multiDeviceList = List.of( |     final List<Device> multiDeviceList = List.of( | ||||||
|         generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), |         generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), | ||||||
|         generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), |         generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), | ||||||
|         generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) |         generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|     Account singleDeviceAccount  = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes()); |     Account singleDeviceAccount  = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); | ||||||
|     Account multiDeviceAccount   = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes()); |     Account multiDeviceAccount   = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES); | ||||||
|     internationalAccount         = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes()); |     internationalAccount         = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); | ||||||
| 
 | 
 | ||||||
|     when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); |     when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); | ||||||
|     when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); |     when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); | ||||||
|  | @ -171,7 +191,8 @@ class MessageControllerTest { | ||||||
|         rateLimiters, |         rateLimiters, | ||||||
|         rateLimiter, |         rateLimiter, | ||||||
|         pushNotificationManager, |         pushNotificationManager, | ||||||
|         reportMessageManager |         reportMessageManager, | ||||||
|  |         multiRecipientMessageExecutor | ||||||
|     ); |     ); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -270,7 +291,7 @@ class MessageControllerTest { | ||||||
|         resources.getJerseyTest() |         resources.getJerseyTest() | ||||||
|             .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) |             .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) | ||||||
|             .request() |             .request() | ||||||
|             .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) |             .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) | ||||||
|             .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), |             .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), | ||||||
|                     IncomingMessageList.class), |                     IncomingMessageList.class), | ||||||
|                 MediaType.APPLICATION_JSON_TYPE)); |                 MediaType.APPLICATION_JSON_TYPE)); | ||||||
|  | @ -412,8 +433,9 @@ class MessageControllerTest { | ||||||
|     verifyNoMoreInteractions(messageSender); |     verifyNoMoreInteractions(messageSender); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @ParameterizedTest | ||||||
|   void testGetMessages() { |   @MethodSource | ||||||
|  |   void testGetMessages(boolean receiveStories) { | ||||||
| 
 | 
 | ||||||
|     final long timestampOne = 313377; |     final long timestampOne = 313377; | ||||||
|     final long timestampTwo = 313388; |     final long timestampTwo = 313388; | ||||||
|  | @ -424,19 +446,15 @@ class MessageControllerTest { | ||||||
| 
 | 
 | ||||||
|     final UUID updatedPniOne = UUID.randomUUID(); |     final UUID updatedPniOne = UUID.randomUUID(); | ||||||
| 
 | 
 | ||||||
|     List<Envelope> messages = List.of( |     List<Envelope> envelopes = List.of( | ||||||
|         generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2, |         generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2, | ||||||
|             AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0), |             AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false), | ||||||
|         generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2, |         generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2, | ||||||
|             AuthHelper.VALID_UUID, null, null, 0) |             AuthHelper.VALID_UUID, null, null, 0, true) | ||||||
|     ); |     ); | ||||||
| 
 | 
 | ||||||
|     OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream() |  | ||||||
|         .map(OutgoingMessageEntity::fromEnvelope) |  | ||||||
|         .toList(), false); |  | ||||||
| 
 |  | ||||||
|     when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) |     when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) | ||||||
|         .thenReturn(new Pair<>(messages, false)); |         .thenReturn(new Pair<>(envelopes, false)); | ||||||
| 
 | 
 | ||||||
|     final String userAgent = "Test-UA"; |     final String userAgent = "Test-UA"; | ||||||
| 
 | 
 | ||||||
|  | @ -444,27 +462,39 @@ class MessageControllerTest { | ||||||
|         resources.getJerseyTest().target("/v1/messages/") |         resources.getJerseyTest().target("/v1/messages/") | ||||||
|             .request() |             .request() | ||||||
|             .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) |             .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) | ||||||
|  |             .header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false") | ||||||
|             .header("USer-Agent", userAgent) |             .header("USer-Agent", userAgent) | ||||||
|             .accept(MediaType.APPLICATION_JSON_TYPE) |             .accept(MediaType.APPLICATION_JSON_TYPE) | ||||||
|             .get(OutgoingMessageEntityList.class); |             .get(OutgoingMessageEntityList.class); | ||||||
| 
 | 
 | ||||||
|     assertEquals(response.messages().size(), 2); |     List<OutgoingMessageEntity> messages = response.messages(); | ||||||
|  |     int expectedSize = receiveStories ? 2 : 1; | ||||||
|  |     assertEquals(expectedSize, messages.size()); | ||||||
| 
 | 
 | ||||||
|     assertEquals(response.messages().get(0).timestamp(), timestampOne); |     OutgoingMessageEntity first = messages.get(0); | ||||||
|     assertEquals(response.messages().get(1).timestamp(), timestampTwo); |     assertEquals(first.timestamp(), timestampOne); | ||||||
|  |     assertEquals(first.guid(), messageGuidOne); | ||||||
|  |     assertEquals(first.sourceUuid(), sourceUuid); | ||||||
|  |     assertEquals(updatedPniOne, first.updatedPni()); | ||||||
| 
 | 
 | ||||||
|     assertEquals(response.messages().get(0).guid(), messageGuidOne); |     if (receiveStories) { | ||||||
|     assertEquals(response.messages().get(1).guid(), messageGuidTwo); |       OutgoingMessageEntity second = messages.get(1); | ||||||
| 
 |       assertEquals(second.timestamp(), timestampTwo); | ||||||
|     assertEquals(response.messages().get(0).sourceUuid(), sourceUuid); |       assertEquals(second.guid(), messageGuidTwo); | ||||||
|     assertEquals(response.messages().get(1).sourceUuid(), sourceUuid); |       assertEquals(second.sourceUuid(), sourceUuid); | ||||||
| 
 |       assertNull(second.updatedPni()); | ||||||
|     assertEquals(updatedPniOne, response.messages().get(0).updatedPni()); |     } | ||||||
|     assertNull(response.messages().get(1).updatedPni()); |  | ||||||
| 
 | 
 | ||||||
|     verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); |     verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   private static Stream<Arguments> testGetMessages() { | ||||||
|  |     return Stream.of( | ||||||
|  |         Arguments.of(true), | ||||||
|  |         Arguments.of(false) | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   @Test |   @Test | ||||||
|   void testGetMessagesBadAuth() { |   void testGetMessagesBadAuth() { | ||||||
|     final long timestampOne = 313377; |     final long timestampOne = 313377; | ||||||
|  | @ -644,9 +674,9 @@ class MessageControllerTest { | ||||||
|         resources.getJerseyTest() |         resources.getJerseyTest() | ||||||
|             .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) |             .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) | ||||||
|             .request() |             .request() | ||||||
|             .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) |             .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) | ||||||
|             .put(Entity.entity(new IncomingMessageList( |             .put(Entity.entity(new IncomingMessageList( | ||||||
|                     List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, |                     List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, false, | ||||||
|                     System.currentTimeMillis()), |                     System.currentTimeMillis()), | ||||||
|                 MediaType.APPLICATION_JSON_TYPE)); |                 MediaType.APPLICATION_JSON_TYPE)); | ||||||
| 
 | 
 | ||||||
|  | @ -686,14 +716,166 @@ class MessageControllerTest { | ||||||
|     ); |     ); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception { | ||||||
|  |     bb.putLong(msb);            // uuid (first 8 bytes) | ||||||
|  |     bb.putLong(lsb);            // uuid (last 8 bytes) | ||||||
|  |     int x = deviceId; | ||||||
|  |     // write the device-id in the 7-bit varint format we use, least significant bytes first. | ||||||
|  |     do { | ||||||
|  |       bb.put((byte)(x & 0x7f)); | ||||||
|  |       x = x >>> 7; | ||||||
|  |     } while (x != 0); | ||||||
|  |     bb.putShort((short) regId); // registration id short | ||||||
|  |     bb.put(new byte[48]);       // key material (48 bytes) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception { | ||||||
|  |     // initialize a binary payload according to our wire format | ||||||
|  |     ByteBuffer bb = ByteBuffer.wrap(buffer); | ||||||
|  |     bb.order(ByteOrder.BIG_ENDIAN); | ||||||
|  | 
 | ||||||
|  |     // determine how many recipient/device pairs we will be writing | ||||||
|  |     int count; | ||||||
|  |     if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; } | ||||||
|  |     else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; } | ||||||
|  |     else { throw new Exception("unknown UUID: " + recipientUUID); } | ||||||
|  | 
 | ||||||
|  |     // first write the header header | ||||||
|  |     bb.put(MultiRecipientMessageProvider.VERSION); // version byte | ||||||
|  |     bb.put((byte)count);                           // count varint, # of active devices for this user | ||||||
|  | 
 | ||||||
|  |     long msb = recipientUUID.getMostSignificantBits(); | ||||||
|  |     long lsb = recipientUUID.getLeastSignificantBits(); | ||||||
|  | 
 | ||||||
|  |     // write the recipient data for each recipient/device pair | ||||||
|  |     if (recipientUUID == MULTI_DEVICE_UUID) { | ||||||
|  |       writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1); | ||||||
|  |       writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2); | ||||||
|  |     } else { | ||||||
|  |       writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // now write the actual message body (empty for now) | ||||||
|  |     bb.put(new byte[39]);            // payload (variable but >= 32, 39 bytes here) | ||||||
|  | 
 | ||||||
|  |     // return the input stream | ||||||
|  |     return new ByteArrayInputStream(buffer, 0, bb.position()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @ParameterizedTest | ||||||
|  |   @MethodSource | ||||||
|  |   void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory) throws Exception { | ||||||
|  | 
 | ||||||
|  |     // initialize our binary payload and create an input stream | ||||||
|  |     byte[] buffer = new byte[2048]; | ||||||
|  |     InputStream stream = initializeMultiPayload(recipientUUID, buffer); | ||||||
|  | 
 | ||||||
|  |     // set up the entity to use in our PUT request | ||||||
|  |     Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); | ||||||
|  | 
 | ||||||
|  |     // start building the request | ||||||
|  |     Invocation.Builder bldr = resources | ||||||
|  |         .getJerseyTest() | ||||||
|  |         .target("/v1/messages/multi_recipient") | ||||||
|  |         .queryParam("online", true) | ||||||
|  |         .queryParam("ts", 1663798405641L) | ||||||
|  |         .queryParam("story", isStory) | ||||||
|  |         .request() | ||||||
|  |         .header("User-Agent", "FIXME"); | ||||||
|  | 
 | ||||||
|  |     // add access header if needed | ||||||
|  |     if (authorize) { | ||||||
|  |       String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES); | ||||||
|  |       bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // make the PUT request | ||||||
|  |     Response response = bldr.put(entity); | ||||||
|  | 
 | ||||||
|  |     // We have a 2x2x2 grid of possible situations based on: | ||||||
|  |     //   - recipient enabled stories? | ||||||
|  |     //   - sender is authorized? | ||||||
|  |     //   - message is a story? | ||||||
|  | 
 | ||||||
|  |     if (recipientUUID == MULTI_DEVICE_UUID) { | ||||||
|  |       // This is the case where the recipient has enabled stories. | ||||||
|  |       if(isStory) { | ||||||
|  |         // We are sending a story, so we ignore access checks and expect this | ||||||
|  |         // to go out to both the recipient's devices. | ||||||
|  |         checkGoodMultiRecipientResponse(response, 2); | ||||||
|  |       } else { | ||||||
|  |         // We are not sending a story, so we need to do access checks. | ||||||
|  |         if (authorize) { | ||||||
|  |           // When authorized we send a message to the recipient's devices. | ||||||
|  |           checkGoodMultiRecipientResponse(response, 2); | ||||||
|  |         } else { | ||||||
|  |           // When forbidden, we return a 401 error. | ||||||
|  |           checkBadMultiRecipientResponse(response, 401); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } else { | ||||||
|  |       // This is the case where the recipient has not enabled stories. | ||||||
|  |       if (isStory) { | ||||||
|  |         // We are sending a story, so we ignore access checks. | ||||||
|  |         // this recipient has one device. | ||||||
|  |         checkGoodMultiRecipientResponse(response, 1); | ||||||
|  |       } else { | ||||||
|  |         // We are not sending a story so check access. | ||||||
|  |         if (authorize) { | ||||||
|  |           // If allowed, send a message to the recipient's one device. | ||||||
|  |           checkGoodMultiRecipientResponse(response, 1); | ||||||
|  |         } else { | ||||||
|  |           // If forbidden, return a 401 error. | ||||||
|  |           checkBadMultiRecipientResponse(response, 401); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Arguments here are: recipient-UUID, is-authorized?, is-story? | ||||||
|  |   private static Stream<Arguments> testMultiRecipientMessage() { | ||||||
|  |     return Stream.of( | ||||||
|  |         Arguments.of(MULTI_DEVICE_UUID, false, true), | ||||||
|  |         Arguments.of(MULTI_DEVICE_UUID, false, false), | ||||||
|  |         Arguments.of(SINGLE_DEVICE_UUID, false, true), | ||||||
|  |         Arguments.of(SINGLE_DEVICE_UUID, false, false), | ||||||
|  |         Arguments.of(MULTI_DEVICE_UUID, true, true), | ||||||
|  |         Arguments.of(MULTI_DEVICE_UUID, true, false), | ||||||
|  |         Arguments.of(SINGLE_DEVICE_UUID, true, true), | ||||||
|  |         Arguments.of(SINGLE_DEVICE_UUID, true, false) | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { | ||||||
|  |     assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); | ||||||
|  |     verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); | ||||||
|  |     verify(multiRecipientMessageExecutor, never()).invokeAll(any()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception { | ||||||
|  |     assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); | ||||||
|  |     verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); | ||||||
|  |     ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class); | ||||||
|  |     verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture()); | ||||||
|  |     assert (captor.getValue().size() == expectedCount); | ||||||
|  |     SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); | ||||||
|  |     assert (smrmr.getUUIDs404().isEmpty()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, |   private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, | ||||||
|       int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { |       int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { | ||||||
|  |     return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, | ||||||
|  |       int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) { | ||||||
| 
 | 
 | ||||||
|     final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() |     final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() | ||||||
|         .setType(MessageProtos.Envelope.Type.forNumber(type)) |         .setType(MessageProtos.Envelope.Type.forNumber(type)) | ||||||
|         .setTimestamp(timestamp) |         .setTimestamp(timestamp) | ||||||
|         .setServerTimestamp(serverTimestamp) |         .setServerTimestamp(serverTimestamp) | ||||||
|         .setDestinationUuid(destinationUuid.toString()) |         .setDestinationUuid(destinationUuid.toString()) | ||||||
|  |         .setStory(story) | ||||||
|         .setServerGuid(guid.toString()); |         .setServerGuid(guid.toString()); | ||||||
| 
 | 
 | ||||||
|     if (sourceUuid != null) { |     if (sourceUuid != null) { | ||||||
|  |  | ||||||
|  | @ -36,7 +36,8 @@ class OutgoingMessageEntityTest { | ||||||
|         updatedPni, |         updatedPni, | ||||||
|         messageContent, |         messageContent, | ||||||
|         serverTimestamp, |         serverTimestamp, | ||||||
|         true); |         true, | ||||||
|  |         false); | ||||||
| 
 | 
 | ||||||
|     assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); |     assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -70,7 +70,7 @@ class MessageMetricsTest { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { |   private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { | ||||||
|     return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true); |     return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  |  | ||||||
|  | @ -0,0 +1,15 @@ | ||||||
|  | package org.whispersystems.websocket; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Class containing constants and shared logic for handling stories. | ||||||
|  |  * <p> | ||||||
|  |  * In particular, it defines the way we interpret the X-Signal-Receive-Stories header | ||||||
|  |  * which is used by both WebSockets and by the REST API. | ||||||
|  |  */ | ||||||
|  | public class Stories { | ||||||
|  |   public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories"; | ||||||
|  | 
 | ||||||
|  |   public static boolean parseReceiveStoriesHeader(String s) { | ||||||
|  |     return "true".equals(s); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -35,13 +35,12 @@ public class WebSocketClient { | ||||||
| 
 | 
 | ||||||
|   public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, |   public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, | ||||||
|                          WebSocketMessageFactory messageFactory, |                          WebSocketMessageFactory messageFactory, | ||||||
|                          Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) |                          Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) { | ||||||
|   { |     this.session = session; | ||||||
|     this.session              = session; |     this.remoteEndpoint = remoteEndpoint; | ||||||
|     this.remoteEndpoint       = remoteEndpoint; |     this.messageFactory = messageFactory; | ||||||
|     this.messageFactory       = messageFactory; |  | ||||||
|     this.pendingRequestMapper = pendingRequestMapper; |     this.pendingRequestMapper = pendingRequestMapper; | ||||||
|     this.created              = System.currentTimeMillis(); |     this.created = System.currentTimeMillis(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   public CompletableFuture<WebSocketResponseMessage> sendRequest(String verb, String path, |   public CompletableFuture<WebSocketResponseMessage> sendRequest(String verb, String path, | ||||||
|  | @ -92,6 +91,11 @@ public class WebSocketClient { | ||||||
|     session.close(code, message); |     session.close(code, message); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   public boolean shouldDeliverStories() { | ||||||
|  |     String value = session.getUpgradeRequest().getHeader(Stories.X_SIGNAL_RECEIVE_STORIES); | ||||||
|  |     return Stories.parseReceiveStoriesHeader(value); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   public void hardDisconnectQuietly() { |   public void hardDisconnectQuietly() { | ||||||
|     try { |     try { | ||||||
|       session.disconnect(); |       session.disconnect(); | ||||||
|  |  | ||||||
|  | @ -653,12 +653,14 @@ class WebSocketResourceProviderTest { | ||||||
|     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse(); |     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse(); | ||||||
|     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue(); |     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue(); | ||||||
|     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue(); |     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue(); | ||||||
|  |     assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Signal-Receive-Stories")).isTrue(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|   void testShouldIncludeRequestMessageHeader() { |   void testShouldIncludeRequestMessageHeader() { | ||||||
|     assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse(); |     assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse(); | ||||||
|     assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue(); |     assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue(); | ||||||
|  |     assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Signal-Receive-Stories")).isTrue(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   @Test |   @Test | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 erik-signal
						erik-signal