From 966c3a8f4784f24fbd1ae53258902b7100db9010 Mon Sep 17 00:00:00 2001 From: erik-signal <113138376+erik-signal@users.noreply.github.com> Date: Wed, 5 Oct 2022 10:32:10 -0400 Subject: [PATCH] Add routing for stories. --- .../RateLimitsConfiguration.java | 5 + .../controllers/MessageController.java | 143 +++++++--- .../entities/AccountMismatchedDevices.java | 4 + .../entities/AccountStaleDevices.java | 4 + .../entities/IncomingMessage.java | 2 + .../entities/IncomingMessageList.java | 5 +- .../entities/MismatchedDevices.java | 4 + .../entities/OutgoingMessageEntity.java | 23 +- .../SendMultiRecipientMessageResponse.java | 10 + .../textsecuregcm/entities/StaleDevices.java | 4 + .../textsecuregcm/limits/RateLimiters.java | 8 + .../AuthenticatedConnectListener.java | 1 + .../websocket/WebSocketConnection.java | 11 +- service/src/main/proto/TextSecure.proto | 3 +- .../controllers/MessageControllerTest.java | 246 +++++++++++++++--- .../entities/OutgoingMessageEntityTest.java | 3 +- .../metrics/MessageMetricsTest.java | 2 +- .../org/whispersystems/websocket/Stories.java | 15 ++ .../websocket/WebSocketClient.java | 16 +- .../WebSocketResourceProviderTest.java | 2 + 20 files changed, 425 insertions(+), 86 deletions(-) create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/Stories.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java index ee4aee1aa..4b836a0ed 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java @@ -68,6 +68,9 @@ public class RateLimitsConfiguration { @JsonProperty private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0); + @JsonProperty + private RateLimitConfiguration stories = new RateLimitConfiguration(10_000, 10_000 / (24.0 * 60.0)); + public RateLimitConfiguration getAutoBlock() { return autoBlock; } @@ -148,6 +151,8 @@ public class RateLimitsConfiguration { return checkAccountExistence; } + public RateLimitConfiguration getStories() { return stories; } + public static class RateLimitConfiguration { @JsonProperty private int bucketSize; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 450169e3a..443a65fad 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -22,6 +22,7 @@ import java.util.Base64; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -33,7 +34,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.BadRequestException; @@ -92,6 +95,7 @@ import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; +import org.whispersystems.websocket.Stories; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/messages") @@ -164,7 +168,9 @@ public class MessageController { @NotNull @Valid IncomingMessageList messages) throws RateLimitExceededException { - if (source.isEmpty() && accessKey.isEmpty()) { + boolean isStory = messages.story(); + + if (source.isEmpty() && accessKey.isEmpty() && !isStory) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } @@ -204,11 +210,18 @@ public class MessageController { destination = source.map(AuthenticatedAccount::getAccount); } - OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); + // Stories will be checked by the client; we bypass access checks here for stories. + if (!isStory) { + OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); + } assert (destination.isPresent()); if (source.isPresent() && !isSyncMessage) { - checkRateLimit(source.get(), destination.get(), userAgent); + checkMessageRateLimit(source.get(), destination.get(), userAgent); + } + + if (isStory) { + checkStoryRateLimit(destination.get()); } final Set excludedDeviceIds; @@ -238,12 +251,12 @@ public class MessageController { if (destinationDevice.isPresent()) { Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); - sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), messages.urgent(), incomingMessage, userAgent); + sendIndividualMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), isStory, messages.urgent(), incomingMessage, userAgent); } } - return Response.ok(new SendMessageResponse( - !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build(); + boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1; + return Response.ok(new SendMessageResponse(needsSync)).build(); } catch (NoSuchUserException e) { throw new WebApplicationException(Response.status(404).build()); } catch (MismatchedDevicesException e) { @@ -260,6 +273,35 @@ public class MessageController { } } + + /** + * Build mapping of accounts to devices/registration IDs. + *

+ * Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive + * stories. + * + * @param multiRecipientMessage + * @param uuidToAccountMap + * @return + */ + private Map>> buildDeviceIdAndRegistrationIdMap( + MultiRecipientMessage multiRecipientMessage, + Map uuidToAccountMap + ) { + + Stream recipients = Arrays.stream(multiRecipientMessage.getRecipients()); + + return recipients.collect(Collectors.toMap( + recipient -> uuidToAccountMap.get(recipient.getUuid()), + recipient -> new HashSet<>( + Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), + (a, b) -> { + a.addAll(b); + return a; + } + )); + } + @Timed @Path("/multi_recipient") @PUT @@ -267,43 +309,51 @@ public class MessageController { @Produces(MediaType.APPLICATION_JSON) @FilterAbusiveMessages public Response sendMultiRecipientMessage( - @HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys, + @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys, @HeaderParam("User-Agent") String userAgent, @HeaderParam("X-Forwarded-For") String forwardedFor, @QueryParam("online") boolean online, @QueryParam("ts") long timestamp, + @QueryParam("story") boolean isStory, @NotNull @Valid MultiRecipientMessage multiRecipientMessage) { Map uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) .map(Recipient::getUuid) .distinct() - .collect(Collectors.toUnmodifiableMap(Function.identity(), uuid -> { - Optional account = accountsManager.getByAccountIdentifier(uuid); - if (account.isEmpty()) { - throw new WebApplicationException(Status.NOT_FOUND); - } - return account.get(); - })); - checkAccessKeys(accessKeys, uuidToAccountMap); + .collect(Collectors.toUnmodifiableMap( + Function.identity(), + uuid -> accountsManager + .getByAccountIdentifier(uuid) + .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)))); - final Map>> accountToDeviceIdAndRegistrationIdMap = - Arrays - .stream(multiRecipientMessage.getRecipients()) - .collect(Collectors.toMap( - recipient -> uuidToAccountMap.get(recipient.getUuid()), - recipient -> new HashSet<>( - Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), - (a, b) -> { - a.addAll(b); - return a; - } - )); + // Stories will be checked by the client; we bypass access checks here for stories. + if (!isStory) { + checkAccessKeys(accessKeys, uuidToAccountMap); + } + + final Map>> accountToDeviceIdAndRegistrationIdMap = + buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, uuidToAccountMap); + + // We might filter out all the recipients of a story (if none have enabled stories). + // In this case there is no error so we should just return 200 now. + if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) { + return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build(); + } Collection accountMismatchedDevices = new ArrayList<>(); Collection accountStaleDevices = new ArrayList<>(); uuidToAccountMap.values().forEach(account -> { - final Set deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first) - .collect(Collectors.toSet()); + + if (isStory) { + checkStoryRateLimit(account); + } + + Set deviceIds = accountToDeviceIdAndRegistrationIdMap + .getOrDefault(account, Collections.emptySet()) + .stream() + .map(Pair::first) + .collect(Collectors.toSet()); + try { DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); @@ -351,8 +401,8 @@ public class MessageController { Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); sentMessageCounter.increment(); try { - sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, - multiRecipientMessage.getCommonPayload()); + sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory, + recipient, multiRecipientMessage.getCommonPayload()); } catch (NoSuchUserException e) { uuids404.add(destinationAccount.getUuid()); } @@ -367,6 +417,10 @@ public class MessageController { } private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map uuidToAccountMap) { + // We should not have null access keys when checking access; bail out early. + if (accessKeys == null) { + throw new WebApplicationException(Status.UNAUTHORIZED); + } AtomicBoolean throwUnauthorized = new AtomicBoolean(false); byte[] empty = new byte[16]; final Optional UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); @@ -405,8 +459,11 @@ public class MessageController { @GET @Produces(MediaType.APPLICATION_JSON) public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth, + @HeaderParam(Stories.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader, @HeaderParam("User-Agent") String userAgent) { + boolean shouldReceiveStories = Stories.parseReceiveStoriesHeader(receiveStoriesHeader); + pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent); final OutgoingMessageEntityList outgoingMessages; @@ -416,7 +473,12 @@ public class MessageController { auth.getAuthenticatedDevice().getId(), false); - outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream() + Stream envelopes = messagesAndHasMore.first().stream(); + if (!shouldReceiveStories) { + envelopes = envelopes.filter(e -> !e.getStory()); + } + + outgoingMessages = new OutgoingMessageEntityList(envelopes .map(OutgoingMessageEntity::fromEnvelope) .peek(outgoingMessageEntity -> MessageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), outgoingMessageEntity)) @@ -514,12 +576,13 @@ public class MessageController { .build(); } - private void sendMessage(Optional source, + private void sendIndividualMessage(Optional source, Account destinationAccount, Device destinationDevice, UUID destinationUuid, long timestamp, boolean online, + boolean story, boolean urgent, IncomingMessage incomingMessage, String userAgentString) @@ -532,6 +595,7 @@ public class MessageController { source.map(AuthenticatedAccount::getAccount).orElse(null), source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), timestamp == 0 ? System.currentTimeMillis() : timestamp, + story, urgent); } catch (final IllegalArgumentException e) { logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); @@ -545,10 +609,11 @@ public class MessageController { } } - private void sendMessage(Account destinationAccount, + private void sendCommonPayloadMessage(Account destinationAccount, Device destinationDevice, long timestamp, boolean online, + boolean story, Recipient recipient, byte[] commonPayload) throws NoSuchUserException { try { @@ -566,6 +631,7 @@ public class MessageController { .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) .setServerTimestamp(serverTimestamp) .setContent(ByteString.copyFrom(payload)) + .setStory(story) .setDestinationUuid(destinationAccount.getUuid().toString()); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); @@ -578,7 +644,14 @@ public class MessageController { } } - private void checkRateLimit(AuthenticatedAccount source, Account destination, String userAgent) + private void checkStoryRateLimit(Account destination) { + try { + rateLimiters.getMessagesLimiter().validate(destination.getUuid()); + } catch (final RateLimitExceededException e) { + } + } + + private void checkMessageRateLimit(AuthenticatedAccount source, Account destination, String userAgent) throws RateLimitExceededException { final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java index 992c6d0ab..4991355c6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountMismatchedDevices.java @@ -15,6 +15,10 @@ public class AccountMismatchedDevices { @JsonProperty public final MismatchedDevices devices; + public String toString() { + return "AccountMismatchedDevices(" + uuid + ", " + devices + ")"; + } + public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) { this.uuid = uuid; this.devices = devices; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java index 35d818ea0..bf1282fdc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountStaleDevices.java @@ -15,6 +15,10 @@ public class AccountStaleDevices { @JsonProperty public final StaleDevices devices; + public String toString() { + return "AccountStaleDevices(" + uuid + ", " + devices + ")"; + } + public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { this.uuid = uuid; this.devices = devices; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index fe6176e5f..7cbefe6c3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -17,6 +17,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio @Nullable Account sourceAccount, @Nullable Long sourceDeviceId, final long timestamp, + final boolean story, final boolean urgent) { final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); @@ -31,6 +32,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio .setTimestamp(timestamp) .setServerTimestamp(System.currentTimeMillis()) .setDestinationUuid(destinationUuid.toString()) + .setStory(story) .setUrgent(urgent); if (sourceAccount != null && sourceDeviceId != null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java index d5bb2af4e..7631f70f8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessageList.java @@ -11,14 +11,15 @@ import javax.validation.Valid; import javax.validation.constraints.NotNull; public record IncomingMessageList(@NotNull @Valid List<@NotNull IncomingMessage> messages, - boolean online, boolean urgent, long timestamp) { + boolean online, boolean urgent, boolean story, long timestamp) { @JsonCreator public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNull IncomingMessage> messages, @JsonProperty("online") boolean online, @JsonProperty("urgent") Boolean urgent, + @JsonProperty("story") Boolean story, @JsonProperty("timestamp") long timestamp) { - this(messages, online, urgent == null || urgent, timestamp); + this(messages, online, urgent == null || urgent, story != null && story, timestamp); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java index 981e9b2f5..7f45dc82b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java @@ -21,6 +21,10 @@ public class MismatchedDevices { @VisibleForTesting public MismatchedDevices() {} + public String toString() { + return "MismatchedDevices(" + missingDevices + ", " + extraDevices + ")"; + } + public MismatchedDevices(List missingDevices, List extraDevices) { this.missingDevices = missingDevices; this.extraDevices = extraDevices; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java index 63afaf1ef..4de88c0b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntity.java @@ -13,7 +13,7 @@ import javax.annotation.Nullable; public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, - long serverTimestamp, boolean urgent) { + long serverTimestamp, boolean urgent, boolean story) { public MessageProtos.Envelope toEnvelope() { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() @@ -22,6 +22,7 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab .setServerTimestamp(serverTimestamp()) .setDestinationUuid(destinationUuid().toString()) .setServerGuid(guid().toString()) + .setStory(story) .setUrgent(urgent); if (sourceUuid() != null) { @@ -51,7 +52,8 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.getContent().toByteArray(), envelope.getServerTimestamp(), - envelope.getUrgent()); + envelope.getUrgent(), + envelope.getStory()); } @Override @@ -63,16 +65,23 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab return false; } final OutgoingMessageEntity that = (OutgoingMessageEntity) o; - return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice - && serverTimestamp == that.serverTimestamp && guid.equals(that.guid) - && Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid) - && Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content) && urgent == that.urgent; + return guid.equals(that.guid) && + type == that.type && + timestamp == that.timestamp && + Objects.equals(sourceUuid, that.sourceUuid) && + sourceDevice == that.sourceDevice && + destinationUuid.equals(that.destinationUuid) && + Objects.equals(updatedPni, that.updatedPni) && + Arrays.equals(content, that.content) && + serverTimestamp == that.serverTimestamp && + urgent == that.urgent && + story == that.story; } @Override public int hashCode() { int result = Objects.hash(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, - serverTimestamp, urgent); + serverTimestamp, urgent, story); result = 31 * result + Arrays.hashCode(content); return result; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java index fd41082a8..62635d2ad 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SendMultiRecipientMessageResponse.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.annotations.VisibleForTesting; import java.util.List; import java.util.UUID; @@ -16,6 +17,15 @@ public class SendMultiRecipientMessageResponse { public SendMultiRecipientMessageResponse() { } + public String toString() { + return "SendMultiRecipientMessageResponse(" + uuids404 + ")"; + } + + @VisibleForTesting + public List getUUIDs404() { + return this.uuids404; + } + public SendMultiRecipientMessageResponse(final List uuids404) { this.uuids404 = uuids404; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java index 2d8e4c9ac..98be70197 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java @@ -16,6 +16,10 @@ public class StaleDevices { public StaleDevices() {} + public String toString() { + return "StaleDevices(" + staleDevices + ")"; + } + public StaleDevices(List staleDevices) { this.staleDevices = staleDevices; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 4dee74de5..adc625c7e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -37,6 +37,8 @@ public class RateLimiters { private final RateLimiter checkAccountExistenceLimiter; + private final RateLimiter storiesLimiter; + public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) { this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", config.getSmsDestination().getBucketSize(), @@ -118,6 +120,10 @@ public class RateLimiters { this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence", config.getCheckAccountExistence().getBucketSize(), config.getCheckAccountExistence().getLeakRatePerMinute()); + + this.storiesLimiter = new RateLimiter(cacheCluster, "stories", + config.getStories().getBucketSize(), + config.getStories().getLeakRatePerMinute()); } public RateLimiter getAllocateDeviceLimiter() { @@ -199,4 +205,6 @@ public class RateLimiters { public RateLimiter getCheckAccountExistenceLimiter() { return checkAccountExistenceLimiter; } + + public RateLimiter getStoriesLimiter() { return storiesLimiter; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 73e37f68b..e284e6530 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -15,6 +15,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import org.eclipse.jetty.websocket.api.UpgradeResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index f0da9987b..5e47bd429 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -332,7 +332,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final Envelope envelope = messages.get(i); final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); - if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { + final boolean discard; + if (isDesktopClient && envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE) { + discard = true; + } else if (envelope.getStory() && !client.shouldDeliverStories()) { + discard = true; + } else { + discard = false; + } + + if (discard) { messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); discardedMessagesMeter.mark(); diff --git a/service/src/main/proto/TextSecure.proto b/service/src/main/proto/TextSecure.proto index 29c05e376..978f8118c 100644 --- a/service/src/main/proto/TextSecure.proto +++ b/service/src/main/proto/TextSecure.proto @@ -42,7 +42,8 @@ message Envelope { optional string destination_uuid = 13; optional bool urgent = 14 [default=true]; optional string updated_pni = 15; - // next: 16 + optional bool story = 16; // indicates that the content is a story. + // next: 17 } message ProvisioningUuid { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 8fc7b1f6e..fb40d7c2a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -32,15 +32,21 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import java.util.Base64; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import javax.ws.rs.client.Entity; +import javax.ws.rs.client.Invocation; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; @@ -62,11 +68,13 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; +import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.ReceiptSender; @@ -80,6 +88,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.Stories; @ExtendWith(DropwizardExtensionsSupport.class) class MessageControllerTest { @@ -87,10 +96,20 @@ class MessageControllerTest { private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); + private static final int SINGLE_DEVICE_ID1 = 1; + private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); + private static final int MULTI_DEVICE_ID1 = 1; + private static final int MULTI_DEVICE_ID2 = 2; + private static final int MULTI_DEVICE_ID3 = 3; + private static final int MULTI_DEVICE_REG_ID1 = 222; + private static final int MULTI_DEVICE_REG_ID2 = 333; + private static final int MULTI_DEVICE_REG_ID3 = 444; + + private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes(); private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); @@ -116,6 +135,7 @@ class MessageControllerTest { .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .addProvider(RateLimitExceededExceptionMapper.class) + .addProvider(MultiRecipientMessageProvider.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, @@ -125,18 +145,18 @@ class MessageControllerTest { @BeforeEach void setup() { final List singleDeviceList = List.of( - generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) + generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) ); final List multiDeviceList = List.of( - generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), - generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) + generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), + generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) ); - Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes()); - Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes()); - internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes()); + Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); + Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES); + internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); @@ -171,7 +191,8 @@ class MessageControllerTest { rateLimiters, rateLimiter, pushNotificationManager, - reportMessageManager + reportMessageManager, + multiRecipientMessageExecutor ); } @@ -270,7 +291,7 @@ class MessageControllerTest { resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class), MediaType.APPLICATION_JSON_TYPE)); @@ -412,8 +433,9 @@ class MessageControllerTest { verifyNoMoreInteractions(messageSender); } - @Test - void testGetMessages() { + @ParameterizedTest + @MethodSource + void testGetMessages(boolean receiveStories) { final long timestampOne = 313377; final long timestampTwo = 313388; @@ -424,19 +446,15 @@ class MessageControllerTest { final UUID updatedPniOne = UUID.randomUUID(); - List messages = List.of( + List envelopes = List.of( generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2, - AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0), + AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false), generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2, - AuthHelper.VALID_UUID, null, null, 0) + AuthHelper.VALID_UUID, null, null, 0, true) ); - OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream() - .map(OutgoingMessageEntity::fromEnvelope) - .toList(), false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) - .thenReturn(new Pair<>(messages, false)); + .thenReturn(new Pair<>(envelopes, false)); final String userAgent = "Test-UA"; @@ -444,27 +462,39 @@ class MessageControllerTest { resources.getJerseyTest().target("/v1/messages/") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false") .header("USer-Agent", userAgent) .accept(MediaType.APPLICATION_JSON_TYPE) .get(OutgoingMessageEntityList.class); - assertEquals(response.messages().size(), 2); + List messages = response.messages(); + int expectedSize = receiveStories ? 2 : 1; + assertEquals(expectedSize, messages.size()); - assertEquals(response.messages().get(0).timestamp(), timestampOne); - assertEquals(response.messages().get(1).timestamp(), timestampTwo); + OutgoingMessageEntity first = messages.get(0); + assertEquals(first.timestamp(), timestampOne); + assertEquals(first.guid(), messageGuidOne); + assertEquals(first.sourceUuid(), sourceUuid); + assertEquals(updatedPniOne, first.updatedPni()); - assertEquals(response.messages().get(0).guid(), messageGuidOne); - assertEquals(response.messages().get(1).guid(), messageGuidTwo); - - assertEquals(response.messages().get(0).sourceUuid(), sourceUuid); - assertEquals(response.messages().get(1).sourceUuid(), sourceUuid); - - assertEquals(updatedPniOne, response.messages().get(0).updatedPni()); - assertNull(response.messages().get(1).updatedPni()); + if (receiveStories) { + OutgoingMessageEntity second = messages.get(1); + assertEquals(second.timestamp(), timestampTwo); + assertEquals(second.guid(), messageGuidTwo); + assertEquals(second.sourceUuid(), sourceUuid); + assertNull(second.updatedPni()); + } verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); } + private static Stream testGetMessages() { + return Stream.of( + Arguments.of(true), + Arguments.of(false) + ); + } + @Test void testGetMessagesBadAuth() { final long timestampOne = 313377; @@ -644,9 +674,9 @@ class MessageControllerTest { resources.getJerseyTest() .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .request() - .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) + .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(new IncomingMessageList( - List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, + List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, false, System.currentTimeMillis()), MediaType.APPLICATION_JSON_TYPE)); @@ -686,14 +716,166 @@ class MessageControllerTest { ); } + private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception { + bb.putLong(msb); // uuid (first 8 bytes) + bb.putLong(lsb); // uuid (last 8 bytes) + int x = deviceId; + // write the device-id in the 7-bit varint format we use, least significant bytes first. + do { + bb.put((byte)(x & 0x7f)); + x = x >>> 7; + } while (x != 0); + bb.putShort((short) regId); // registration id short + bb.put(new byte[48]); // key material (48 bytes) + } + + private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception { + // initialize a binary payload according to our wire format + ByteBuffer bb = ByteBuffer.wrap(buffer); + bb.order(ByteOrder.BIG_ENDIAN); + + // determine how many recipient/device pairs we will be writing + int count; + if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; } + else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; } + else { throw new Exception("unknown UUID: " + recipientUUID); } + + // first write the header header + bb.put(MultiRecipientMessageProvider.VERSION); // version byte + bb.put((byte)count); // count varint, # of active devices for this user + + long msb = recipientUUID.getMostSignificantBits(); + long lsb = recipientUUID.getLeastSignificantBits(); + + // write the recipient data for each recipient/device pair + if (recipientUUID == MULTI_DEVICE_UUID) { + writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1); + writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2); + } else { + writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1); + } + + // now write the actual message body (empty for now) + bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here) + + // return the input stream + return new ByteArrayInputStream(buffer, 0, bb.position()); + } + + @ParameterizedTest + @MethodSource + void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory) throws Exception { + + // initialize our binary payload and create an input stream + byte[] buffer = new byte[2048]; + InputStream stream = initializeMultiPayload(recipientUUID, buffer); + + // set up the entity to use in our PUT request + Entity entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); + + // start building the request + Invocation.Builder bldr = resources + .getJerseyTest() + .target("/v1/messages/multi_recipient") + .queryParam("online", true) + .queryParam("ts", 1663798405641L) + .queryParam("story", isStory) + .request() + .header("User-Agent", "FIXME"); + + // add access header if needed + if (authorize) { + String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES); + bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes); + } + + // make the PUT request + Response response = bldr.put(entity); + + // We have a 2x2x2 grid of possible situations based on: + // - recipient enabled stories? + // - sender is authorized? + // - message is a story? + + if (recipientUUID == MULTI_DEVICE_UUID) { + // This is the case where the recipient has enabled stories. + if(isStory) { + // We are sending a story, so we ignore access checks and expect this + // to go out to both the recipient's devices. + checkGoodMultiRecipientResponse(response, 2); + } else { + // We are not sending a story, so we need to do access checks. + if (authorize) { + // When authorized we send a message to the recipient's devices. + checkGoodMultiRecipientResponse(response, 2); + } else { + // When forbidden, we return a 401 error. + checkBadMultiRecipientResponse(response, 401); + } + } + } else { + // This is the case where the recipient has not enabled stories. + if (isStory) { + // We are sending a story, so we ignore access checks. + // this recipient has one device. + checkGoodMultiRecipientResponse(response, 1); + } else { + // We are not sending a story so check access. + if (authorize) { + // If allowed, send a message to the recipient's one device. + checkGoodMultiRecipientResponse(response, 1); + } else { + // If forbidden, return a 401 error. + checkBadMultiRecipientResponse(response, 401); + } + } + } + } + + // Arguments here are: recipient-UUID, is-authorized?, is-story? + private static Stream testMultiRecipientMessage() { + return Stream.of( + Arguments.of(MULTI_DEVICE_UUID, false, true), + Arguments.of(MULTI_DEVICE_UUID, false, false), + Arguments.of(SINGLE_DEVICE_UUID, false, true), + Arguments.of(SINGLE_DEVICE_UUID, false, false), + Arguments.of(MULTI_DEVICE_UUID, true, true), + Arguments.of(MULTI_DEVICE_UUID, true, false), + Arguments.of(SINGLE_DEVICE_UUID, true, true), + Arguments.of(SINGLE_DEVICE_UUID, true, false) + ); + } + + private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception { + assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode))); + verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); + verify(multiRecipientMessageExecutor, never()).invokeAll(any()); + } + + private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception { + assertThat("Unexpected response", response.getStatus(), is(equalTo(200))); + verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean()); + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture()); + assert (captor.getValue().size() == expectedCount); + SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class); + assert (smrmr.getUUIDs404().isEmpty()); + } + private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { + return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); + } + + private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, + int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() .setType(MessageProtos.Envelope.Type.forNumber(type)) .setTimestamp(timestamp) .setServerTimestamp(serverTimestamp) .setDestinationUuid(destinationUuid.toString()) + .setStory(story) .setServerGuid(guid.toString()); if (sourceUuid != null) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index fd65f5f31..ef22f8727 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -36,7 +36,8 @@ class OutgoingMessageEntityTest { updatedPni, messageContent, serverTimestamp, - true); + true, + false); assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java index 34489dfc3..97fc3db3d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MessageMetricsTest.java @@ -70,7 +70,7 @@ class MessageMetricsTest { } private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { - return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true); + return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false); } @Test diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/Stories.java b/websocket-resources/src/main/java/org/whispersystems/websocket/Stories.java new file mode 100644 index 000000000..9cf859126 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/Stories.java @@ -0,0 +1,15 @@ +package org.whispersystems.websocket; + +/** + * Class containing constants and shared logic for handling stories. + *

+ * In particular, it defines the way we interpret the X-Signal-Receive-Stories header + * which is used by both WebSockets and by the REST API. + */ +public class Stories { + public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories"; + + public static boolean parseReceiveStoriesHeader(String s) { + return "true".equals(s); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java index 2d8384789..8977d4c2e 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketClient.java @@ -35,13 +35,12 @@ public class WebSocketClient { public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, WebSocketMessageFactory messageFactory, - Map> pendingRequestMapper) - { - this.session = session; - this.remoteEndpoint = remoteEndpoint; - this.messageFactory = messageFactory; + Map> pendingRequestMapper) { + this.session = session; + this.remoteEndpoint = remoteEndpoint; + this.messageFactory = messageFactory; this.pendingRequestMapper = pendingRequestMapper; - this.created = System.currentTimeMillis(); + this.created = System.currentTimeMillis(); } public CompletableFuture sendRequest(String verb, String path, @@ -92,6 +91,11 @@ public class WebSocketClient { session.close(code, message); } + public boolean shouldDeliverStories() { + String value = session.getUpgradeRequest().getHeader(Stories.X_SIGNAL_RECEIVE_STORIES); + return Stories.parseReceiveStoriesHeader(value); + } + public void hardDisconnectQuietly() { try { session.disconnect(); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index ac5bd45ce..dc3a84443 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -653,12 +653,14 @@ class WebSocketResourceProviderTest { assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse(); assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue(); assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue(); + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Signal-Receive-Stories")).isTrue(); } @Test void testShouldIncludeRequestMessageHeader() { assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse(); assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue(); + assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Signal-Receive-Stories")).isTrue(); } @Test