From 8c74ad073b463a6f66140a83451a4fa4711791e2 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 9 Jan 2014 15:20:06 -1000 Subject: [PATCH] Rework messages API to fail if you miss some deviceIds per number --- .../controllers/FederationController.java | 52 ++++++-- .../controllers/MessageController.java | 116 ++++++++++++------ .../entities/MessageResponse.java | 11 +- .../federation/FederatedClient.java | 13 +- .../textsecuregcm/push/PushSender.java | 43 +++++-- .../textsecuregcm/util/Pair.java | 14 ++- 6 files changed, 183 insertions(+), 66 deletions(-) diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java index 0682b290c..2dcb96f00 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java @@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; +import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.federation.FederatedPeer; @@ -34,9 +35,11 @@ import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.UrlSigner; import org.whispersystems.textsecuregcm.util.Util; +import javax.print.attribute.standard.Media; import javax.validation.Valid; import javax.ws.rs.Consumes; import javax.ws.rs.GET; @@ -49,8 +52,12 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; import java.net.URL; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.Set; @Path("/v1/federation") public class FederationController { @@ -102,22 +109,49 @@ public class FederationController { @PUT @Path("/message") @Consumes(MediaType.APPLICATION_JSON) - public void relayMessage(@Auth FederatedPeer peer, @Valid RelayMessage message) + @Produces(MediaType.APPLICATION_JSON) + public MessageResponse relayMessage(@Auth FederatedPeer peer, @Valid List messages) throws IOException { try { - OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal()) - .toBuilder() - .setRelay(peer.getName()) - .build(); + Map>> destinations = new HashMap<>(); - pushSender.sendMessage(message.getDestination(), message.getDestinationDeviceId(), signal); + for (RelayMessage message : messages) { + Pair> deviceIds = destinations.get(message.getDestination()); + if (deviceIds == null) { + deviceIds = new Pair>(true, new HashSet()); + destinations.put(message.getDestination(), deviceIds); + } + deviceIds.second().add(message.getDestinationDeviceId()); + } + + Map, Account> accountCache = new HashMap<>(); + List numbersMissingDevices = new LinkedList<>(); + pushSender.fillLocalAccountsCache(destinations, accountCache, numbersMissingDevices); + + List success = new LinkedList<>(); + List failure = new LinkedList<>(numbersMissingDevices); + + for (RelayMessage message : messages) { + Account account = accountCache.get(new Pair<>(message.getDestination(), message.getDestinationDeviceId())); + if (account == null) + continue; + OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal()) + .toBuilder() + .setRelay(peer.getName()) + .build(); + try { + pushSender.sendMessage(account, signal); + } catch (NoSuchUserException e) { + logger.info("No such user", e); + failure.add(message.getDestination()); + } + } + + return new MessageResponse(success, failure, numbersMissingDevices); } catch (InvalidProtocolBufferException ipe) { logger.warn("ProtoBuf", ipe); throw new WebApplicationException(Response.status(400).build()); - } catch (NoSuchUserException e) { - logger.debug("No User", e); - throw new WebApplicationException(Response.status(404).build()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index e6359e4ba..f666cf759 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -35,26 +35,31 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import org.whispersystems.textsecuregcm.entities.MessageResponse; +import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.federation.FederatedClient; import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.IterablePair; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; +import javax.annotation.Nullable; import javax.servlet.AsyncContext; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.BufferedReader; import java.io.IOException; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -99,14 +104,20 @@ public class MessageController extends HttpServlet { rateLimiters.getMessagesLimiter().validate(sender.getNumber()); + + Map, Account> accountCache = new HashMap<>(); + List numbersMissingDevices = new LinkedList<>(); + List incomingMessages = messages.getMessages(); List outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), - incomingMessages); + incomingMessages, + accountCache, + numbersMissingDevices); IterablePair listPair = new IterablePair<>(incomingMessages, outgoingMessages); - handleAsyncDelivery(timerContext, req.startAsync(), listPair); + handleAsyncDelivery(timerContext, req.startAsync(), listPair, accountCache, numbersMissingDevices); } catch (AuthenticationException e) { failureMeter.mark(); timerContext.stop(); @@ -129,32 +140,68 @@ public class MessageController extends HttpServlet { private void handleAsyncDelivery(final TimerContext timerContext, final AsyncContext context, - final IterablePair listPair) + final IterablePair listPair, + final Map, Account> accountCache, + final List numbersMissingDevices) { executor.submit(new Runnable() { @Override public void run() { List success = new LinkedList<>(); - List failure = new LinkedList<>(); + List failure = new LinkedList<>(numbersMissingDevices); HttpServletResponse response = (HttpServletResponse) context.getResponse(); try { + Map>> relayMessages = new HashMap<>(); for (Pair messagePair : listPair) { String destination = messagePair.first().getDestination(); long destinationDeviceId = messagePair.first().getDestinationDeviceId(); String relay = messagePair.first().getRelay(); + if (Util.isEmpty(relay)) { + try { + pushSender.sendMessage(accountCache.get(new Pair<>(destination, destinationDeviceId)), messagePair.second()); + } catch (NoSuchUserException e) { + logger.debug("No such user", e); + failure.add(destination); + } + } else { + Set> messageSet = relayMessages.get(relay); + if (messageSet == null) { + messageSet = new HashSet<>(); + relayMessages.put(relay, messageSet); + } + messageSet.add(messagePair); + } + success.add(destination); + } + + for (Map.Entry>> messagesForRelay : relayMessages.entrySet()) { try { - if (Util.isEmpty(relay)) sendLocalMessage(destination, destinationDeviceId, messagePair.second()); - else sendRelayMessage(relay, destination, destinationDeviceId, messagePair.second()); - success.add(destination); - } catch (NoSuchUserException e) { - logger.debug("No such user", e); - failure.add(destination); + FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey()); + + List messages = new LinkedList<>(); + for (Pair message : messagesForRelay.getValue()) { + messages.add(new RelayMessage(message.first().getDestination(), + message.first().getDestinationDeviceId(), + message.second().toByteArray())); + } + + MessageResponse relayResponse = client.sendMessages(messages); + for (String string : relayResponse.getSuccess()) + success.add(string); + for (String string : relayResponse.getFailure()) + failure.add(string); + for (String string : relayResponse.getNumbersMissingDevices()) + numbersMissingDevices.add(string); + } catch (NoSuchPeerException e) { + logger.info("No such peer", e); + for (Pair messagePair : messagesForRelay.getValue()) + failure.add(messagePair.first().getDestination()); } } - byte[] responseData = serializeResponse(new MessageResponse(success, failure)); + byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices)); response.setContentLength(responseData.length); response.getOutputStream().write(responseData); context.complete(); @@ -171,36 +218,33 @@ public class MessageController extends HttpServlet { }); } - private void sendLocalMessage(String destination, long destinationDeviceId, OutgoingMessageSignal outgoingMessage) - throws IOException, NoSuchUserException - { - pushSender.sendMessage(destination, destinationDeviceId, outgoingMessage); - } - - private void sendRelayMessage(String relay, String destination, long destinationDeviceId, OutgoingMessageSignal outgoingMessage) - throws IOException, NoSuchUserException - { - try { - FederatedClient client = federatedClientManager.getClient(relay); - client.sendMessage(destination, destinationDeviceId, outgoingMessage); - } catch (NoSuchPeerException e) { - logger.info("No such peer", e); - throw new NoSuchUserException(e); - } - } - - private List getOutgoingMessageSignals(String number, - List incomingMessages) + /** + * @param accountCache is a map from Pair to the account + */ + @Nullable + private List getOutgoingMessageSignals(String sourceNumber, + List incomingMessages, + Map, Account> accountCache, + List numbersMissingDevices) { List outgoingMessages = new LinkedList<>(); - Set destinations = new HashSet<>(); - for (IncomingMessage incoming : incomingMessages) - destinations.add(incoming.getDestination()); + // # local deviceIds + Map>> destinations = new HashMap<>(); + for (IncomingMessage incoming : incomingMessages) { + Pair> deviceIds = destinations.get(incoming.getDestination()); + if (deviceIds == null) { + deviceIds = new Pair>(Util.isEmpty(incoming.getRelay()), new HashSet()); + destinations.put(incoming.getDestination(), deviceIds); + } + deviceIds.second().add(incoming.getDestinationDeviceId()); + } + + pushSender.fillLocalAccountsCache(destinations, accountCache, numbersMissingDevices); for (IncomingMessage incoming : incomingMessages) { OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder(); outgoingMessage.setType(incoming.getType()); - outgoingMessage.setSource(number); + outgoingMessage.setSource(sourceNumber); byte[] messageBody = getMessageBody(incoming); @@ -212,7 +256,7 @@ public class MessageController extends HttpServlet { int index = 0; - for (String destination : destinations) { + for (String destination : destinations.keySet()) { if (!destination.equals(incoming.getDestination())) { outgoingMessage.setDestinations(index++, destination); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java index a04e96f10..fc8870a42 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java @@ -21,10 +21,12 @@ import java.util.List; public class MessageResponse { private List success; private List failure; + private List missingDeviceIds; - public MessageResponse(List success, List failure) { - this.success = success; - this.failure = failure; + public MessageResponse(List success, List failure, List missingDeviceIds) { + this.success = success; + this.failure = failure; + this.missingDeviceIds = missingDeviceIds; } public MessageResponse() {} @@ -37,4 +39,7 @@ public class MessageResponse { return failure; } + public List getNumbersMissingDevices() { + return missingDeviceIds; + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/federation/FederatedClient.java b/src/main/java/org/whispersystems/textsecuregcm/federation/FederatedClient.java index a008e43e3..eca167789 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/federation/FederatedClient.java +++ b/src/main/java/org/whispersystems/textsecuregcm/federation/FederatedClient.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; +import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; @@ -140,23 +141,21 @@ public class FederatedClient { } } - public void sendMessage(String destination, long destinationDeviceId, OutgoingMessageSignal message) - throws IOException, NoSuchUserException + public MessageResponse sendMessages(List messages) + throws IOException { try { WebResource resource = client.resource(peer.getUrl()).path(RELAY_MESSAGE_PATH); ClientResponse response = resource.type(MediaType.APPLICATION_JSON) .header("Authorization", authorizationHeader) - .entity(new RelayMessage(destination, destinationDeviceId, message.toByteArray())) + .entity(messages) .put(ClientResponse.class); - if (response.getStatus() == 404) { - throw new NoSuchUserException("No remote user: " + destination); - } - if (response.getStatus() != 200 && response.getStatus() != 204) { throw new IOException("Bad response: " + response.getStatus()); } + + return response.getEntity(MessageResponse.class); } catch (UniformInterfaceException | ClientHandlerException e) { logger.warn("sendMessage", e); throw new IOException(e); diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index 07e6b34bd..088b2b297 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -28,19 +28,21 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.StoredMessageManager; +import org.whispersystems.textsecuregcm.util.Pair; import java.io.IOException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.util.List; +import java.util.Map; +import java.util.Set; public class PushSender { private final Logger logger = LoggerFactory.getLogger(PushSender.class); private final AccountsManager accounts; - private final DirectoryManager directory; private final GCMSender gcmSender; private final APNSender apnSender; @@ -54,23 +56,44 @@ public class PushSender { throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException { this.accounts = accounts; - this.directory = directory; this.storedMessageManager = storedMessageManager; this.gcmSender = new GCMSender(gcmConfiguration.getApiKey()); this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey()); } - public void sendMessage(String destination, long destinationDeviceId, MessageProtos.OutgoingMessageSignal outgoingMessage) + /** + * For each local destination in destinations, either adds all its accounts to accountCache or adds the number to + * numbersMissingDevices, if the deviceIds list don't match what is required. + * @param destinations Map from number to Pair<localNumber, Set<deviceIds>> + * @param accountCache Map from <number, deviceId> to account + * @param numbersMissingDevices list of numbers missing devices + */ + public void fillLocalAccountsCache(Map>> destinations, Map, Account> accountCache, List numbersMissingDevices) { + for (Map.Entry>> destination : destinations.entrySet()) { + if (destination.getValue().first()) { + String number = destination.getKey(); + List accountList = accounts.getAllByNumber(number); + Set deviceIdsIncluded = destination.getValue().second(); + if (accountList.size() != deviceIdsIncluded.size()) + numbersMissingDevices.add(number); + else { + for (Account account : accountList) { + if (!deviceIdsIncluded.contains(account.getDeviceId())) { + numbersMissingDevices.add(number); + break; + } + } + for (Account account : accountList) + accountCache.put(new Pair<>(number, account.getDeviceId()), account); + } + } + } + } + + public void sendMessage(Account account, MessageProtos.OutgoingMessageSignal outgoingMessage) throws IOException, NoSuchUserException { - Optional accountOptional = accounts.get(destination, destinationDeviceId); - - if (!accountOptional.isPresent()) { - throw new NoSuchUserException("No such local destination: " + destination); - } - Account account = accountOptional.get(); - String signalingKey = account.getSignalingKey(); EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey); diff --git a/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java b/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java index 601f5a229..795854969 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java +++ b/src/main/java/org/whispersystems/textsecuregcm/util/Pair.java @@ -1,10 +1,12 @@ package org.whispersystems.textsecuregcm.util; +import static com.google.common.base.Objects.equal; + public class Pair { private final T1 v1; private final T2 v2; - Pair(T1 v1, T2 v2) { + public Pair(T1 v1, T2 v2) { this.v1 = v1; this.v2 = v2; } @@ -16,4 +18,14 @@ public class Pair { public T2 second(){ return v2; } + + public boolean equals(Object o) { + return o instanceof Pair && + equal(((Pair) o).first(), first()) && + equal(((Pair) o).second(), second()); + } + + public int hashCode() { + return first().hashCode() ^ second().hashCode(); + } }