Add routing for stories.

This commit is contained in:
erik-signal 2022-10-05 10:32:10 -04:00 committed by Erik Osheim
parent c2ab72c77e
commit 966c3a8f47
20 changed files with 425 additions and 86 deletions

View File

@ -68,6 +68,9 @@ public class RateLimitsConfiguration {
@JsonProperty @JsonProperty
private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0); private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0);
@JsonProperty
private RateLimitConfiguration stories = new RateLimitConfiguration(10_000, 10_000 / (24.0 * 60.0));
public RateLimitConfiguration getAutoBlock() { public RateLimitConfiguration getAutoBlock() {
return autoBlock; return autoBlock;
} }
@ -148,6 +151,8 @@ public class RateLimitsConfiguration {
return checkAccountExistence; return checkAccountExistence;
} }
public RateLimitConfiguration getStories() { return stories; }
public static class RateLimitConfiguration { public static class RateLimitConfiguration {
@JsonProperty @JsonProperty
private int bucketSize; private int bucketSize;

View File

@ -22,6 +22,7 @@ import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -33,7 +34,9 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException; import javax.ws.rs.BadRequestException;
@ -92,6 +95,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.Stories;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/messages") @Path("/v1/messages")
@ -164,7 +168,9 @@ public class MessageController {
@NotNull @Valid IncomingMessageList messages) @NotNull @Valid IncomingMessageList messages)
throws RateLimitExceededException { throws RateLimitExceededException {
if (source.isEmpty() && accessKey.isEmpty()) { boolean isStory = messages.story();
if (source.isEmpty() && accessKey.isEmpty() && !isStory) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
@ -204,11 +210,18 @@ public class MessageController {
destination = source.map(AuthenticatedAccount::getAccount); destination = source.map(AuthenticatedAccount::getAccount);
} }
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); // Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
}
assert (destination.isPresent()); assert (destination.isPresent());
if (source.isPresent() && !isSyncMessage) { if (source.isPresent() && !isSyncMessage) {
checkRateLimit(source.get(), destination.get(), userAgent); checkMessageRateLimit(source.get(), destination.get(), userAgent);
}
if (isStory) {
checkStoryRateLimit(destination.get());
} }
final Set<Long> excludedDeviceIds; final Set<Long> excludedDeviceIds;
@ -238,12 +251,12 @@ public class MessageController {
if (destinationDevice.isPresent()) { if (destinationDevice.isPresent()) {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), messages.urgent(), incomingMessage, userAgent); sendIndividualMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), isStory, messages.urgent(), incomingMessage, userAgent);
} }
} }
return Response.ok(new SendMessageResponse( boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1;
!isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build(); return Response.ok(new SendMessageResponse(needsSync)).build();
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build()); throw new WebApplicationException(Response.status(404).build());
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
@ -260,6 +273,35 @@ public class MessageController {
} }
} }
/**
* Build mapping of accounts to devices/registration IDs.
* <p>
* Messages that are stories will only be sent to the subset of recipients who have indicated they want to receive
* stories.
*
* @param multiRecipientMessage
* @param uuidToAccountMap
* @return
*/
private Map<Account, Set<Pair<Long, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage,
Map<UUID, Account> uuidToAccountMap
) {
Stream<Recipient> recipients = Arrays.stream(multiRecipientMessage.getRecipients());
return recipients.collect(Collectors.toMap(
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>(
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))),
(a, b) -> {
a.addAll(b);
return a;
}
));
}
@Timed @Timed
@Path("/multi_recipient") @Path("/multi_recipient")
@PUT @PUT
@ -267,43 +309,51 @@ public class MessageController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@FilterAbusiveMessages @FilterAbusiveMessages
public Response sendMultiRecipientMessage( public Response sendMultiRecipientMessage(
@HeaderParam(OptionalAccess.UNIDENTIFIED) CombinedUnidentifiedSenderAccessKeys accessKeys, @HeaderParam(OptionalAccess.UNIDENTIFIED) @Nullable CombinedUnidentifiedSenderAccessKeys accessKeys,
@HeaderParam("User-Agent") String userAgent, @HeaderParam("User-Agent") String userAgent,
@HeaderParam("X-Forwarded-For") String forwardedFor, @HeaderParam("X-Forwarded-For") String forwardedFor,
@QueryParam("online") boolean online, @QueryParam("online") boolean online,
@QueryParam("ts") long timestamp, @QueryParam("ts") long timestamp,
@QueryParam("story") boolean isStory,
@NotNull @Valid MultiRecipientMessage multiRecipientMessage) { @NotNull @Valid MultiRecipientMessage multiRecipientMessage) {
Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients()) Map<UUID, Account> uuidToAccountMap = Arrays.stream(multiRecipientMessage.getRecipients())
.map(Recipient::getUuid) .map(Recipient::getUuid)
.distinct() .distinct()
.collect(Collectors.toUnmodifiableMap(Function.identity(), uuid -> { .collect(Collectors.toUnmodifiableMap(
Optional<Account> account = accountsManager.getByAccountIdentifier(uuid); Function.identity(),
if (account.isEmpty()) { uuid -> accountsManager
throw new WebApplicationException(Status.NOT_FOUND); .getByAccountIdentifier(uuid)
} .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND))));
return account.get();
}));
checkAccessKeys(accessKeys, uuidToAccountMap);
final Map<Account, HashSet<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap = // Stories will be checked by the client; we bypass access checks here for stories.
Arrays if (!isStory) {
.stream(multiRecipientMessage.getRecipients()) checkAccessKeys(accessKeys, uuidToAccountMap);
.collect(Collectors.toMap( }
recipient -> uuidToAccountMap.get(recipient.getUuid()),
recipient -> new HashSet<>( final Map<Account, Set<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
Collections.singletonList(new Pair<>(recipient.getDeviceId(), recipient.getRegistrationId()))), buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, uuidToAccountMap);
(a, b) -> {
a.addAll(b); // We might filter out all the recipients of a story (if none have enabled stories).
return a; // In this case there is no error so we should just return 200 now.
} if (isStory && accountToDeviceIdAndRegistrationIdMap.isEmpty()) {
)); return Response.ok(new SendMultiRecipientMessageResponse(new LinkedList<>())).build();
}
Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>(); Collection<AccountMismatchedDevices> accountMismatchedDevices = new ArrayList<>();
Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>(); Collection<AccountStaleDevices> accountStaleDevices = new ArrayList<>();
uuidToAccountMap.values().forEach(account -> { uuidToAccountMap.values().forEach(account -> {
final Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap.get(account).stream().map(Pair::first)
.collect(Collectors.toSet()); if (isStory) {
checkStoryRateLimit(account);
}
Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap
.getOrDefault(account, Collections.emptySet())
.stream()
.map(Pair::first)
.collect(Collectors.toSet());
try { try {
DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet()); DestinationDeviceValidator.validateCompleteDeviceList(account, deviceIds, Collections.emptySet());
@ -351,8 +401,8 @@ public class MessageController {
Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow(); Device destinationDevice = destinationAccount.getDevice(recipient.getDeviceId()).orElseThrow();
sentMessageCounter.increment(); sentMessageCounter.increment();
try { try {
sendMessage(destinationAccount, destinationDevice, timestamp, online, recipient, sendCommonPayloadMessage(destinationAccount, destinationDevice, timestamp, online, isStory,
multiRecipientMessage.getCommonPayload()); recipient, multiRecipientMessage.getCommonPayload());
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
uuids404.add(destinationAccount.getUuid()); uuids404.add(destinationAccount.getUuid());
} }
@ -367,6 +417,10 @@ public class MessageController {
} }
private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) { private void checkAccessKeys(CombinedUnidentifiedSenderAccessKeys accessKeys, Map<UUID, Account> uuidToAccountMap) {
// We should not have null access keys when checking access; bail out early.
if (accessKeys == null) {
throw new WebApplicationException(Status.UNAUTHORIZED);
}
AtomicBoolean throwUnauthorized = new AtomicBoolean(false); AtomicBoolean throwUnauthorized = new AtomicBoolean(false);
byte[] empty = new byte[16]; byte[] empty = new byte[16];
final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]); final Optional<byte[]> UNRESTRICTED_UNIDENTIFIED_ACCESS_KEY = Optional.of(new byte[16]);
@ -405,8 +459,11 @@ public class MessageController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth, public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth,
@HeaderParam(Stories.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader,
@HeaderParam("User-Agent") String userAgent) { @HeaderParam("User-Agent") String userAgent) {
boolean shouldReceiveStories = Stories.parseReceiveStoriesHeader(receiveStoriesHeader);
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent); pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
final OutgoingMessageEntityList outgoingMessages; final OutgoingMessageEntityList outgoingMessages;
@ -416,7 +473,12 @@ public class MessageController {
auth.getAuthenticatedDevice().getId(), auth.getAuthenticatedDevice().getId(),
false); false);
outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream() Stream<Envelope> envelopes = messagesAndHasMore.first().stream();
if (!shouldReceiveStories) {
envelopes = envelopes.filter(e -> !e.getStory());
}
outgoingMessages = new OutgoingMessageEntityList(envelopes
.map(OutgoingMessageEntity::fromEnvelope) .map(OutgoingMessageEntity::fromEnvelope)
.peek(outgoingMessageEntity -> MessageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), .peek(outgoingMessageEntity -> MessageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(),
outgoingMessageEntity)) outgoingMessageEntity))
@ -514,12 +576,13 @@ public class MessageController {
.build(); .build();
} }
private void sendMessage(Optional<AuthenticatedAccount> source, private void sendIndividualMessage(Optional<AuthenticatedAccount> source,
Account destinationAccount, Account destinationAccount,
Device destinationDevice, Device destinationDevice,
UUID destinationUuid, UUID destinationUuid,
long timestamp, long timestamp,
boolean online, boolean online,
boolean story,
boolean urgent, boolean urgent,
IncomingMessage incomingMessage, IncomingMessage incomingMessage,
String userAgentString) String userAgentString)
@ -532,6 +595,7 @@ public class MessageController {
source.map(AuthenticatedAccount::getAccount).orElse(null), source.map(AuthenticatedAccount::getAccount).orElse(null),
source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null),
timestamp == 0 ? System.currentTimeMillis() : timestamp, timestamp == 0 ? System.currentTimeMillis() : timestamp,
story,
urgent); urgent);
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString);
@ -545,10 +609,11 @@ public class MessageController {
} }
} }
private void sendMessage(Account destinationAccount, private void sendCommonPayloadMessage(Account destinationAccount,
Device destinationDevice, Device destinationDevice,
long timestamp, long timestamp,
boolean online, boolean online,
boolean story,
Recipient recipient, Recipient recipient,
byte[] commonPayload) throws NoSuchUserException { byte[] commonPayload) throws NoSuchUserException {
try { try {
@ -566,6 +631,7 @@ public class MessageController {
.setTimestamp(timestamp == 0 ? serverTimestamp : timestamp) .setTimestamp(timestamp == 0 ? serverTimestamp : timestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setContent(ByteString.copyFrom(payload)) .setContent(ByteString.copyFrom(payload))
.setStory(story)
.setDestinationUuid(destinationAccount.getUuid().toString()); .setDestinationUuid(destinationAccount.getUuid().toString());
messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online); messageSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build(), online);
@ -578,7 +644,14 @@ public class MessageController {
} }
} }
private void checkRateLimit(AuthenticatedAccount source, Account destination, String userAgent) private void checkStoryRateLimit(Account destination) {
try {
rateLimiters.getMessagesLimiter().validate(destination.getUuid());
} catch (final RateLimitExceededException e) {
}
}
private void checkMessageRateLimit(AuthenticatedAccount source, Account destination, String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber()); final String senderCountryCode = Util.getCountryCode(source.getAccount().getNumber());

View File

@ -15,6 +15,10 @@ public class AccountMismatchedDevices {
@JsonProperty @JsonProperty
public final MismatchedDevices devices; public final MismatchedDevices devices;
public String toString() {
return "AccountMismatchedDevices(" + uuid + ", " + devices + ")";
}
public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) { public AccountMismatchedDevices(final UUID uuid, final MismatchedDevices devices) {
this.uuid = uuid; this.uuid = uuid;
this.devices = devices; this.devices = devices;

View File

@ -15,6 +15,10 @@ public class AccountStaleDevices {
@JsonProperty @JsonProperty
public final StaleDevices devices; public final StaleDevices devices;
public String toString() {
return "AccountStaleDevices(" + uuid + ", " + devices + ")";
}
public AccountStaleDevices(final UUID uuid, final StaleDevices devices) { public AccountStaleDevices(final UUID uuid, final StaleDevices devices) {
this.uuid = uuid; this.uuid = uuid;
this.devices = devices; this.devices = devices;

View File

@ -17,6 +17,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio
@Nullable Account sourceAccount, @Nullable Account sourceAccount,
@Nullable Long sourceDeviceId, @Nullable Long sourceDeviceId,
final long timestamp, final long timestamp,
final boolean story,
final boolean urgent) { final boolean urgent) {
final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type());
@ -31,6 +32,7 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio
.setTimestamp(timestamp) .setTimestamp(timestamp)
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setDestinationUuid(destinationUuid.toString()) .setDestinationUuid(destinationUuid.toString())
.setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) { if (sourceAccount != null && sourceDeviceId != null) {

View File

@ -11,14 +11,15 @@ import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
public record IncomingMessageList(@NotNull @Valid List<@NotNull IncomingMessage> messages, public record IncomingMessageList(@NotNull @Valid List<@NotNull IncomingMessage> messages,
boolean online, boolean urgent, long timestamp) { boolean online, boolean urgent, boolean story, long timestamp) {
@JsonCreator @JsonCreator
public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNull IncomingMessage> messages, public IncomingMessageList(@JsonProperty("messages") @NotNull @Valid List<@NotNull IncomingMessage> messages,
@JsonProperty("online") boolean online, @JsonProperty("online") boolean online,
@JsonProperty("urgent") Boolean urgent, @JsonProperty("urgent") Boolean urgent,
@JsonProperty("story") Boolean story,
@JsonProperty("timestamp") long timestamp) { @JsonProperty("timestamp") long timestamp) {
this(messages, online, urgent == null || urgent, timestamp); this(messages, online, urgent == null || urgent, story != null && story, timestamp);
} }
} }

View File

@ -21,6 +21,10 @@ public class MismatchedDevices {
@VisibleForTesting @VisibleForTesting
public MismatchedDevices() {} public MismatchedDevices() {}
public String toString() {
return "MismatchedDevices(" + missingDevices + ", " + extraDevices + ")";
}
public MismatchedDevices(List<Long> missingDevices, List<Long> extraDevices) { public MismatchedDevices(List<Long> missingDevices, List<Long> extraDevices) {
this.missingDevices = missingDevices; this.missingDevices = missingDevices;
this.extraDevices = extraDevices; this.extraDevices = extraDevices;

View File

@ -13,7 +13,7 @@ import javax.annotation.Nullable;
public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice, public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullable UUID sourceUuid, int sourceDevice,
UUID destinationUuid, @Nullable UUID updatedPni, byte[] content, UUID destinationUuid, @Nullable UUID updatedPni, byte[] content,
long serverTimestamp, boolean urgent) { long serverTimestamp, boolean urgent, boolean story) {
public MessageProtos.Envelope toEnvelope() { public MessageProtos.Envelope toEnvelope() {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
@ -22,6 +22,7 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab
.setServerTimestamp(serverTimestamp()) .setServerTimestamp(serverTimestamp())
.setDestinationUuid(destinationUuid().toString()) .setDestinationUuid(destinationUuid().toString())
.setServerGuid(guid().toString()) .setServerGuid(guid().toString())
.setStory(story)
.setUrgent(urgent); .setUrgent(urgent);
if (sourceUuid() != null) { if (sourceUuid() != null) {
@ -51,7 +52,8 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab
envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null, envelope.hasUpdatedPni() ? UUID.fromString(envelope.getUpdatedPni()) : null,
envelope.getContent().toByteArray(), envelope.getContent().toByteArray(),
envelope.getServerTimestamp(), envelope.getServerTimestamp(),
envelope.getUrgent()); envelope.getUrgent(),
envelope.getStory());
} }
@Override @Override
@ -63,16 +65,23 @@ public record OutgoingMessageEntity(UUID guid, int type, long timestamp, @Nullab
return false; return false;
} }
final OutgoingMessageEntity that = (OutgoingMessageEntity) o; final OutgoingMessageEntity that = (OutgoingMessageEntity) o;
return type == that.type && timestamp == that.timestamp && sourceDevice == that.sourceDevice return guid.equals(that.guid) &&
&& serverTimestamp == that.serverTimestamp && guid.equals(that.guid) type == that.type &&
&& Objects.equals(sourceUuid, that.sourceUuid) && destinationUuid.equals(that.destinationUuid) timestamp == that.timestamp &&
&& Objects.equals(updatedPni, that.updatedPni) && Arrays.equals(content, that.content) && urgent == that.urgent; Objects.equals(sourceUuid, that.sourceUuid) &&
sourceDevice == that.sourceDevice &&
destinationUuid.equals(that.destinationUuid) &&
Objects.equals(updatedPni, that.updatedPni) &&
Arrays.equals(content, that.content) &&
serverTimestamp == that.serverTimestamp &&
urgent == that.urgent &&
story == that.story;
} }
@Override @Override
public int hashCode() { public int hashCode() {
int result = Objects.hash(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, int result = Objects.hash(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni,
serverTimestamp, urgent); serverTimestamp, urgent, story);
result = 31 * result + Arrays.hashCode(content); result = 31 * result + Arrays.hashCode(content);
return result; return result;
} }

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
@ -16,6 +17,15 @@ public class SendMultiRecipientMessageResponse {
public SendMultiRecipientMessageResponse() { public SendMultiRecipientMessageResponse() {
} }
public String toString() {
return "SendMultiRecipientMessageResponse(" + uuids404 + ")";
}
@VisibleForTesting
public List<UUID> getUUIDs404() {
return this.uuids404;
}
public SendMultiRecipientMessageResponse(final List<UUID> uuids404) { public SendMultiRecipientMessageResponse(final List<UUID> uuids404) {
this.uuids404 = uuids404; this.uuids404 = uuids404;
} }

View File

@ -16,6 +16,10 @@ public class StaleDevices {
public StaleDevices() {} public StaleDevices() {}
public String toString() {
return "StaleDevices(" + staleDevices + ")";
}
public StaleDevices(List<Long> staleDevices) { public StaleDevices(List<Long> staleDevices) {
this.staleDevices = staleDevices; this.staleDevices = staleDevices;
} }

View File

@ -37,6 +37,8 @@ public class RateLimiters {
private final RateLimiter checkAccountExistenceLimiter; private final RateLimiter checkAccountExistenceLimiter;
private final RateLimiter storiesLimiter;
public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) { public RateLimiters(RateLimitsConfiguration config, FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination",
config.getSmsDestination().getBucketSize(), config.getSmsDestination().getBucketSize(),
@ -118,6 +120,10 @@ public class RateLimiters {
this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence", this.checkAccountExistenceLimiter = new RateLimiter(cacheCluster, "checkAccountExistence",
config.getCheckAccountExistence().getBucketSize(), config.getCheckAccountExistence().getBucketSize(),
config.getCheckAccountExistence().getLeakRatePerMinute()); config.getCheckAccountExistence().getLeakRatePerMinute());
this.storiesLimiter = new RateLimiter(cacheCluster, "stories",
config.getStories().getBucketSize(),
config.getStories().getLeakRatePerMinute());
} }
public RateLimiter getAllocateDeviceLimiter() { public RateLimiter getAllocateDeviceLimiter() {
@ -199,4 +205,6 @@ public class RateLimiters {
public RateLimiter getCheckAccountExistenceLimiter() { public RateLimiter getCheckAccountExistenceLimiter() {
return checkAccountExistenceLimiter; return checkAccountExistenceLimiter;
} }
public RateLimiter getStoriesLimiter() { return storiesLimiter; }
} }

View File

@ -15,6 +15,7 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;

View File

@ -332,7 +332,16 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
final Envelope envelope = messages.get(i); final Envelope envelope = messages.get(i);
final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { final boolean discard;
if (isDesktopClient && envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE) {
discard = true;
} else if (envelope.getStory() && !client.shouldDeliverStories()) {
discard = true;
} else {
discard = false;
}
if (discard) {
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp()); messagesManager.delete(auth.getAccount().getUuid(), device.getId(), messageGuid, envelope.getServerTimestamp());
discardedMessagesMeter.mark(); discardedMessagesMeter.mark();

View File

@ -42,7 +42,8 @@ message Envelope {
optional string destination_uuid = 13; optional string destination_uuid = 13;
optional bool urgent = 14 [default=true]; optional bool urgent = 14 [default=true];
optional string updated_pni = 15; optional string updated_pni = 15;
// next: 16 optional bool story = 16; // indicates that the content is a story.
// next: 17
} }
message ProvisioningUuid { message ProvisioningUuid {

View File

@ -32,15 +32,21 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@ -62,11 +68,13 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -80,6 +88,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.Stories;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest { class MessageControllerTest {
@ -87,10 +96,20 @@ class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID(); private static final UUID SINGLE_DEVICE_UUID = UUID.randomUUID();
private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID(); private static final UUID SINGLE_DEVICE_PNI = UUID.randomUUID();
private static final int SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID(); private static final UUID MULTI_DEVICE_UUID = UUID.randomUUID();
private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID(); private static final UUID MULTI_DEVICE_PNI = UUID.randomUUID();
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
private static final byte[] UNIDENTIFIED_ACCESS_BYTES = "0123456789abcdef".getBytes();
private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID(); private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
@ -116,6 +135,7 @@ class MessageControllerTest {
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class) .addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(MultiRecipientMessageProvider.class)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
@ -125,18 +145,18 @@ class MessageControllerTest {
@BeforeEach @BeforeEach
void setup() { void setup() {
final List<Device> singleDeviceList = List.of( final List<Device> singleDeviceList = List.of(
generateTestDevice(1, 111, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis()) generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, 1111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
); );
final List<Device> multiDeviceList = List.of( final List<Device> multiDeviceList = List.of(
generateTestDevice(1, 222, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, 2222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(2, 333, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, 3333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(3, 444, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, 4444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
); );
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes()); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, "1234".getBytes()); Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES);
internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes()); internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
@ -171,7 +191,8 @@ class MessageControllerTest {
rateLimiters, rateLimiters,
rateLimiter, rateLimiter,
pushNotificationManager, pushNotificationManager,
reportMessageManager reportMessageManager,
multiRecipientMessageExecutor
); );
} }
@ -270,7 +291,7 @@ class MessageControllerTest {
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"), .put(Entity.entity(SystemMapper.getMapper().readValue(jsonFixture("fixtures/current_message_single_device.json"),
IncomingMessageList.class), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
@ -412,8 +433,9 @@ class MessageControllerTest {
verifyNoMoreInteractions(messageSender); verifyNoMoreInteractions(messageSender);
} }
@Test @ParameterizedTest
void testGetMessages() { @MethodSource
void testGetMessages(boolean receiveStories) {
final long timestampOne = 313377; final long timestampOne = 313377;
final long timestampTwo = 313388; final long timestampTwo = 313388;
@ -424,19 +446,15 @@ class MessageControllerTest {
final UUID updatedPniOne = UUID.randomUUID(); final UUID updatedPniOne = UUID.randomUUID();
List<Envelope> messages = List.of( List<Envelope> envelopes = List.of(
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2, generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2,
AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0), AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false),
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2, generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2,
AuthHelper.VALID_UUID, null, null, 0) AuthHelper.VALID_UUID, null, null, 0, true)
); );
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages.stream()
.map(OutgoingMessageEntity::fromEnvelope)
.toList(), false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(messages, false)); .thenReturn(new Pair<>(envelopes, false));
final String userAgent = "Test-UA"; final String userAgent = "Test-UA";
@ -444,27 +462,39 @@ class MessageControllerTest {
resources.getJerseyTest().target("/v1/messages/") resources.getJerseyTest().target("/v1/messages/")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(Stories.X_SIGNAL_RECEIVE_STORIES, receiveStories ? "true" : "false")
.header("USer-Agent", userAgent) .header("USer-Agent", userAgent)
.accept(MediaType.APPLICATION_JSON_TYPE) .accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class); .get(OutgoingMessageEntityList.class);
assertEquals(response.messages().size(), 2); List<OutgoingMessageEntity> messages = response.messages();
int expectedSize = receiveStories ? 2 : 1;
assertEquals(expectedSize, messages.size());
assertEquals(response.messages().get(0).timestamp(), timestampOne); OutgoingMessageEntity first = messages.get(0);
assertEquals(response.messages().get(1).timestamp(), timestampTwo); assertEquals(first.timestamp(), timestampOne);
assertEquals(first.guid(), messageGuidOne);
assertEquals(first.sourceUuid(), sourceUuid);
assertEquals(updatedPniOne, first.updatedPni());
assertEquals(response.messages().get(0).guid(), messageGuidOne); if (receiveStories) {
assertEquals(response.messages().get(1).guid(), messageGuidTwo); OutgoingMessageEntity second = messages.get(1);
assertEquals(second.timestamp(), timestampTwo);
assertEquals(response.messages().get(0).sourceUuid(), sourceUuid); assertEquals(second.guid(), messageGuidTwo);
assertEquals(response.messages().get(1).sourceUuid(), sourceUuid); assertEquals(second.sourceUuid(), sourceUuid);
assertNull(second.updatedPni());
assertEquals(updatedPniOne, response.messages().get(0).updatedPni()); }
assertNull(response.messages().get(1).updatedPni());
verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent); verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent);
} }
private static Stream<Arguments> testGetMessages() {
return Stream.of(
Arguments.of(true),
Arguments.of(false)
);
}
@Test @Test
void testGetMessagesBadAuth() { void testGetMessagesBadAuth() {
final long timestampOne = 313377; final long timestampOne = 313377;
@ -644,9 +674,9 @@ class MessageControllerTest {
resources.getJerseyTest() resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID)) .target(String.format("/v1/messages/%s", SINGLE_DEVICE_UUID))
.request() .request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString("1234".getBytes())) .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList( .put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, false,
System.currentTimeMillis()), System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
@ -686,14 +716,166 @@ class MessageControllerTest {
); );
} }
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, int deviceId, int regId) throws Exception {
bb.putLong(msb); // uuid (first 8 bytes)
bb.putLong(lsb); // uuid (last 8 bytes)
int x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
bb.put((byte)(x & 0x7f));
x = x >>> 7;
} while (x != 0);
bb.putShort((short) regId); // registration id short
bb.put(new byte[48]); // key material (48 bytes)
}
private InputStream initializeMultiPayload(UUID recipientUUID, byte[] buffer) throws Exception {
// initialize a binary payload according to our wire format
ByteBuffer bb = ByteBuffer.wrap(buffer);
bb.order(ByteOrder.BIG_ENDIAN);
// determine how many recipient/device pairs we will be writing
int count;
if (recipientUUID == MULTI_DEVICE_UUID) { count = 2; }
else if (recipientUUID == SINGLE_DEVICE_UUID) { count = 1; }
else { throw new Exception("unknown UUID: " + recipientUUID); }
// first write the header header
bb.put(MultiRecipientMessageProvider.VERSION); // version byte
bb.put((byte)count); // count varint, # of active devices for this user
long msb = recipientUUID.getMostSignificantBits();
long lsb = recipientUUID.getLeastSignificantBits();
// write the recipient data for each recipient/device pair
if (recipientUUID == MULTI_DEVICE_UUID) {
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1);
writeMultiPayloadRecipient(bb, msb, lsb, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2);
} else {
writeMultiPayloadRecipient(bb, msb, lsb, SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1);
}
// now write the actual message body (empty for now)
bb.put(new byte[39]); // payload (variable but >= 32, 39 bytes here)
// return the input stream
return new ByteArrayInputStream(buffer, 0, bb.position());
}
@ParameterizedTest
@MethodSource
void testMultiRecipientMessage(UUID recipientUUID, boolean authorize, boolean isStory) throws Exception {
// initialize our binary payload and create an input stream
byte[] buffer = new byte[2048];
InputStream stream = initializeMultiPayload(recipientUUID, buffer);
// set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request
Invocation.Builder bldr = resources
.getJerseyTest()
.target("/v1/messages/multi_recipient")
.queryParam("online", true)
.queryParam("ts", 1663798405641L)
.queryParam("story", isStory)
.request()
.header("User-Agent", "FIXME");
// add access header if needed
if (authorize) {
String encodedBytes = Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES);
bldr = bldr.header(OptionalAccess.UNIDENTIFIED, encodedBytes);
}
// make the PUT request
Response response = bldr.put(entity);
// We have a 2x2x2 grid of possible situations based on:
// - recipient enabled stories?
// - sender is authorized?
// - message is a story?
if (recipientUUID == MULTI_DEVICE_UUID) {
// This is the case where the recipient has enabled stories.
if(isStory) {
// We are sending a story, so we ignore access checks and expect this
// to go out to both the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// We are not sending a story, so we need to do access checks.
if (authorize) {
// When authorized we send a message to the recipient's devices.
checkGoodMultiRecipientResponse(response, 2);
} else {
// When forbidden, we return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
} else {
// This is the case where the recipient has not enabled stories.
if (isStory) {
// We are sending a story, so we ignore access checks.
// this recipient has one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// We are not sending a story so check access.
if (authorize) {
// If allowed, send a message to the recipient's one device.
checkGoodMultiRecipientResponse(response, 1);
} else {
// If forbidden, return a 401 error.
checkBadMultiRecipientResponse(response, 401);
}
}
}
}
// Arguments here are: recipient-UUID, is-authorized?, is-story?
private static Stream<Arguments> testMultiRecipientMessage() {
return Stream.of(
Arguments.of(MULTI_DEVICE_UUID, false, true),
Arguments.of(MULTI_DEVICE_UUID, false, false),
Arguments.of(SINGLE_DEVICE_UUID, false, true),
Arguments.of(SINGLE_DEVICE_UUID, false, false),
Arguments.of(MULTI_DEVICE_UUID, true, true),
Arguments.of(MULTI_DEVICE_UUID, true, false),
Arguments.of(SINGLE_DEVICE_UUID, true, true),
Arguments.of(SINGLE_DEVICE_UUID, true, false)
);
}
private void checkBadMultiRecipientResponse(Response response, int expectedCode) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(expectedCode)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
verify(multiRecipientMessageExecutor, never()).invokeAll(any());
}
private void checkGoodMultiRecipientResponse(Response response, int expectedCount) throws Exception {
assertThat("Unexpected response", response.getStatus(), is(equalTo(200)));
verify(messageSender, never()).sendMessage(any(), any(), any(), anyBoolean());
ArgumentCaptor<List<Callable<Void>>> captor = ArgumentCaptor.forClass(List.class);
verify(multiRecipientMessageExecutor, times(1)).invokeAll(captor.capture());
assert (captor.getValue().size() == expectedCount);
SendMultiRecipientMessageResponse smrmr = response.readEntity(SendMultiRecipientMessageResponse.class);
assert (smrmr.getUUIDs404().isEmpty());
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type)) .setType(MessageProtos.Envelope.Type.forNumber(type))
.setTimestamp(timestamp) .setTimestamp(timestamp)
.setServerTimestamp(serverTimestamp) .setServerTimestamp(serverTimestamp)
.setDestinationUuid(destinationUuid.toString()) .setDestinationUuid(destinationUuid.toString())
.setStory(story)
.setServerGuid(guid.toString()); .setServerGuid(guid.toString());
if (sourceUuid != null) { if (sourceUuid != null) {

View File

@ -36,7 +36,8 @@ class OutgoingMessageEntityTest {
updatedPni, updatedPni,
messageContent, messageContent,
serverTimestamp, serverTimestamp,
true); true,
false);
assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope())); assertEquals(outgoingMessageEntity, OutgoingMessageEntity.fromEnvelope(outgoingMessageEntity.toEnvelope()));
} }

View File

@ -70,7 +70,7 @@ class MessageMetricsTest {
} }
private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) { private OutgoingMessageEntity createOutgoingMessageEntity(UUID destinationUuid) {
return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true); return new OutgoingMessageEntity(UUID.randomUUID(), 1, 1L, null, 1, destinationUuid, null, new byte[]{}, 1, true, false);
} }
@Test @Test

View File

@ -0,0 +1,15 @@
package org.whispersystems.websocket;
/**
* Class containing constants and shared logic for handling stories.
* <p>
* In particular, it defines the way we interpret the X-Signal-Receive-Stories header
* which is used by both WebSockets and by the REST API.
*/
public class Stories {
public final static String X_SIGNAL_RECEIVE_STORIES = "X-Signal-Receive-Stories";
public static boolean parseReceiveStoriesHeader(String s) {
return "true".equals(s);
}
}

View File

@ -35,13 +35,12 @@ public class WebSocketClient {
public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint,
WebSocketMessageFactory messageFactory, WebSocketMessageFactory messageFactory,
Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) {
{ this.session = session;
this.session = session; this.remoteEndpoint = remoteEndpoint;
this.remoteEndpoint = remoteEndpoint; this.messageFactory = messageFactory;
this.messageFactory = messageFactory;
this.pendingRequestMapper = pendingRequestMapper; this.pendingRequestMapper = pendingRequestMapper;
this.created = System.currentTimeMillis(); this.created = System.currentTimeMillis();
} }
public CompletableFuture<WebSocketResponseMessage> sendRequest(String verb, String path, public CompletableFuture<WebSocketResponseMessage> sendRequest(String verb, String path,
@ -92,6 +91,11 @@ public class WebSocketClient {
session.close(code, message); session.close(code, message);
} }
public boolean shouldDeliverStories() {
String value = session.getUpgradeRequest().getHeader(Stories.X_SIGNAL_RECEIVE_STORIES);
return Stories.parseReceiveStoriesHeader(value);
}
public void hardDisconnectQuietly() { public void hardDisconnectQuietly() {
try { try {
session.disconnect(); session.disconnect();

View File

@ -653,12 +653,14 @@ class WebSocketResourceProviderTest {
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse(); assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse();
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue(); assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue();
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue(); assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue();
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Signal-Receive-Stories")).isTrue();
} }
@Test @Test
void testShouldIncludeRequestMessageHeader() { void testShouldIncludeRequestMessageHeader() {
assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse(); assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse();
assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue(); assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue();
assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Signal-Receive-Stories")).isTrue();
} }
@Test @Test