Rework messages API to fail if you miss some deviceIds per number

This commit is contained in:
Matt Corallo 2014-01-09 15:20:06 -10:00
parent 918ef4a7ca
commit 8c74ad073b
6 changed files with 183 additions and 66 deletions

View File

@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
@ -34,9 +35,11 @@ import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.util.Util;
import javax.print.attribute.standard.Media;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
@ -49,8 +52,12 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@Path("/v1/federation")
public class FederationController {
@ -102,22 +109,49 @@ public class FederationController {
@PUT
@Path("/message")
@Consumes(MediaType.APPLICATION_JSON)
public void relayMessage(@Auth FederatedPeer peer, @Valid RelayMessage message)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse relayMessage(@Auth FederatedPeer peer, @Valid List<RelayMessage> messages)
throws IOException
{
try {
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
.toBuilder()
.setRelay(peer.getName())
.build();
Map<String, Pair<Boolean, Set<Long>>> destinations = new HashMap<>();
pushSender.sendMessage(message.getDestination(), message.getDestinationDeviceId(), signal);
for (RelayMessage message : messages) {
Pair<Boolean, Set<Long>> deviceIds = destinations.get(message.getDestination());
if (deviceIds == null) {
deviceIds = new Pair<Boolean, Set<Long>>(true, new HashSet<Long>());
destinations.put(message.getDestination(), deviceIds);
}
deviceIds.second().add(message.getDestinationDeviceId());
}
Map<Pair<String, Long>, Account> accountCache = new HashMap<>();
List<String> numbersMissingDevices = new LinkedList<>();
pushSender.fillLocalAccountsCache(destinations, accountCache, numbersMissingDevices);
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>(numbersMissingDevices);
for (RelayMessage message : messages) {
Account account = accountCache.get(new Pair<>(message.getDestination(), message.getDestinationDeviceId()));
if (account == null)
continue;
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
.toBuilder()
.setRelay(peer.getName())
.build();
try {
pushSender.sendMessage(account, signal);
} catch (NoSuchUserException e) {
logger.info("No such user", e);
failure.add(message.getDestination());
}
}
return new MessageResponse(success, failure, numbersMissingDevices);
} catch (InvalidProtocolBufferException ipe) {
logger.warn("ProtoBuf", ipe);
throw new WebApplicationException(Response.status(400).build());
} catch (NoSuchUserException e) {
logger.debug("No User", e);
throw new WebApplicationException(Response.status(404).build());
}
}

View File

@ -35,26 +35,31 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.IterablePair;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@ -99,14 +104,20 @@ public class MessageController extends HttpServlet {
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
Map<Pair<String, Long>, Account> accountCache = new HashMap<>();
List<String> numbersMissingDevices = new LinkedList<>();
List<IncomingMessage> incomingMessages = messages.getMessages();
List<OutgoingMessageSignal> outgoingMessages = getOutgoingMessageSignals(sender.getNumber(),
incomingMessages);
incomingMessages,
accountCache,
numbersMissingDevices);
IterablePair<IncomingMessage, OutgoingMessageSignal> listPair = new IterablePair<>(incomingMessages,
outgoingMessages);
handleAsyncDelivery(timerContext, req.startAsync(), listPair);
handleAsyncDelivery(timerContext, req.startAsync(), listPair, accountCache, numbersMissingDevices);
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
@ -129,32 +140,68 @@ public class MessageController extends HttpServlet {
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final IterablePair<IncomingMessage, OutgoingMessageSignal> listPair)
final IterablePair<IncomingMessage, OutgoingMessageSignal> listPair,
final Map<Pair<String, Long>, Account> accountCache,
final List<String> numbersMissingDevices)
{
executor.submit(new Runnable() {
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
List<String> failure = new LinkedList<>(numbersMissingDevices);
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try {
Map<String, Set<Pair<IncomingMessage, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : listPair) {
String destination = messagePair.first().getDestination();
long destinationDeviceId = messagePair.first().getDestinationDeviceId();
String relay = messagePair.first().getRelay();
if (Util.isEmpty(relay)) {
try {
pushSender.sendMessage(accountCache.get(new Pair<>(destination, destinationDeviceId)), messagePair.second());
} catch (NoSuchUserException e) {
logger.debug("No such user", e);
failure.add(destination);
}
} else {
Set<Pair<IncomingMessage, OutgoingMessageSignal>> messageSet = relayMessages.get(relay);
if (messageSet == null) {
messageSet = new HashSet<>();
relayMessages.put(relay, messageSet);
}
messageSet.add(messagePair);
}
success.add(destination);
}
for (Map.Entry<String, Set<Pair<IncomingMessage, OutgoingMessageSignal>>> messagesForRelay : relayMessages.entrySet()) {
try {
if (Util.isEmpty(relay)) sendLocalMessage(destination, destinationDeviceId, messagePair.second());
else sendRelayMessage(relay, destination, destinationDeviceId, messagePair.second());
success.add(destination);
} catch (NoSuchUserException e) {
logger.debug("No such user", e);
failure.add(destination);
FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey());
List<RelayMessage> messages = new LinkedList<>();
for (Pair<IncomingMessage, OutgoingMessageSignal> message : messagesForRelay.getValue()) {
messages.add(new RelayMessage(message.first().getDestination(),
message.first().getDestinationDeviceId(),
message.second().toByteArray()));
}
MessageResponse relayResponse = client.sendMessages(messages);
for (String string : relayResponse.getSuccess())
success.add(string);
for (String string : relayResponse.getFailure())
failure.add(string);
for (String string : relayResponse.getNumbersMissingDevices())
numbersMissingDevices.add(string);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
failure.add(messagePair.first().getDestination());
}
}
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
@ -171,36 +218,33 @@ public class MessageController extends HttpServlet {
});
}
private void sendLocalMessage(String destination, long destinationDeviceId, OutgoingMessageSignal outgoingMessage)
throws IOException, NoSuchUserException
{
pushSender.sendMessage(destination, destinationDeviceId, outgoingMessage);
}
private void sendRelayMessage(String relay, String destination, long destinationDeviceId, OutgoingMessageSignal outgoingMessage)
throws IOException, NoSuchUserException
{
try {
FederatedClient client = federatedClientManager.getClient(relay);
client.sendMessage(destination, destinationDeviceId, outgoingMessage);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
throw new NoSuchUserException(e);
}
}
private List<OutgoingMessageSignal> getOutgoingMessageSignals(String number,
List<IncomingMessage> incomingMessages)
/**
* @param accountCache is a map from Pair<number, deviceId> to the account
*/
@Nullable
private List<OutgoingMessageSignal> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages,
Map<Pair<String, Long>, Account> accountCache,
List<String> numbersMissingDevices)
{
List<OutgoingMessageSignal> outgoingMessages = new LinkedList<>();
Set<String> destinations = new HashSet<>();
for (IncomingMessage incoming : incomingMessages)
destinations.add(incoming.getDestination());
// # local deviceIds
Map<String, Pair<Boolean, Set<Long>>> destinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
Pair<Boolean, Set<Long>> deviceIds = destinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new Pair<Boolean, Set<Long>>(Util.isEmpty(incoming.getRelay()), new HashSet<Long>());
destinations.put(incoming.getDestination(), deviceIds);
}
deviceIds.second().add(incoming.getDestinationDeviceId());
}
pushSender.fillLocalAccountsCache(destinations, accountCache, numbersMissingDevices);
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
outgoingMessage.setType(incoming.getType());
outgoingMessage.setSource(number);
outgoingMessage.setSource(sourceNumber);
byte[] messageBody = getMessageBody(incoming);
@ -212,7 +256,7 @@ public class MessageController extends HttpServlet {
int index = 0;
for (String destination : destinations) {
for (String destination : destinations.keySet()) {
if (!destination.equals(incoming.getDestination())) {
outgoingMessage.setDestinations(index++, destination);
}

View File

@ -21,10 +21,12 @@ import java.util.List;
public class MessageResponse {
private List<String> success;
private List<String> failure;
private List<String> missingDeviceIds;
public MessageResponse(List<String> success, List<String> failure) {
this.success = success;
this.failure = failure;
public MessageResponse(List<String> success, List<String> failure, List<String> missingDeviceIds) {
this.success = success;
this.failure = failure;
this.missingDeviceIds = missingDeviceIds;
}
public MessageResponse() {}
@ -37,4 +39,7 @@ public class MessageResponse {
return failure;
}
public List<String> getNumbersMissingDevices() {
return missingDeviceIds;
}
}

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
@ -140,23 +141,21 @@ public class FederatedClient {
}
}
public void sendMessage(String destination, long destinationDeviceId, OutgoingMessageSignal message)
throws IOException, NoSuchUserException
public MessageResponse sendMessages(List<RelayMessage> messages)
throws IOException
{
try {
WebResource resource = client.resource(peer.getUrl()).path(RELAY_MESSAGE_PATH);
ClientResponse response = resource.type(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.entity(new RelayMessage(destination, destinationDeviceId, message.toByteArray()))
.entity(messages)
.put(ClientResponse.class);
if (response.getStatus() == 404) {
throw new NoSuchUserException("No remote user: " + destination);
}
if (response.getStatus() != 200 && response.getStatus() != 204) {
throw new IOException("Bad response: " + response.getStatus());
}
return response.getEntity(MessageResponse.class);
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("sendMessage", e);
throw new IOException(e);

View File

@ -28,19 +28,21 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.util.Pair;
import java.io.IOException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class PushSender {
private final Logger logger = LoggerFactory.getLogger(PushSender.class);
private final AccountsManager accounts;
private final DirectoryManager directory;
private final GCMSender gcmSender;
private final APNSender apnSender;
@ -54,23 +56,44 @@ public class PushSender {
throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{
this.accounts = accounts;
this.directory = directory;
this.storedMessageManager = storedMessageManager;
this.gcmSender = new GCMSender(gcmConfiguration.getApiKey());
this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey());
}
public void sendMessage(String destination, long destinationDeviceId, MessageProtos.OutgoingMessageSignal outgoingMessage)
/**
* For each local destination in destinations, either adds all its accounts to accountCache or adds the number to
* numbersMissingDevices, if the deviceIds list don't match what is required.
* @param destinations Map from number to Pair&lt;localNumber, Set&lt;deviceIds&gt;&gt;
* @param accountCache Map from &lt;number, deviceId&gt; to account
* @param numbersMissingDevices list of numbers missing devices
*/
public void fillLocalAccountsCache(Map<String, Pair<Boolean, Set<Long>>> destinations, Map<Pair<String, Long>, Account> accountCache, List<String> numbersMissingDevices) {
for (Map.Entry<String, Pair<Boolean, Set<Long>>> destination : destinations.entrySet()) {
if (destination.getValue().first()) {
String number = destination.getKey();
List<Account> accountList = accounts.getAllByNumber(number);
Set<Long> deviceIdsIncluded = destination.getValue().second();
if (accountList.size() != deviceIdsIncluded.size())
numbersMissingDevices.add(number);
else {
for (Account account : accountList) {
if (!deviceIdsIncluded.contains(account.getDeviceId())) {
numbersMissingDevices.add(number);
break;
}
}
for (Account account : accountList)
accountCache.put(new Pair<>(number, account.getDeviceId()), account);
}
}
}
}
public void sendMessage(Account account, MessageProtos.OutgoingMessageSignal outgoingMessage)
throws IOException, NoSuchUserException
{
Optional<Account> accountOptional = accounts.get(destination, destinationDeviceId);
if (!accountOptional.isPresent()) {
throw new NoSuchUserException("No such local destination: " + destination);
}
Account account = accountOptional.get();
String signalingKey = account.getSignalingKey();
EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey);

View File

@ -1,10 +1,12 @@
package org.whispersystems.textsecuregcm.util;
import static com.google.common.base.Objects.equal;
public class Pair<T1, T2> {
private final T1 v1;
private final T2 v2;
Pair(T1 v1, T2 v2) {
public Pair(T1 v1, T2 v2) {
this.v1 = v1;
this.v2 = v2;
}
@ -16,4 +18,14 @@ public class Pair<T1, T2> {
public T2 second(){
return v2;
}
public boolean equals(Object o) {
return o instanceof Pair &&
equal(((Pair) o).first(), first()) &&
equal(((Pair) o).second(), second());
}
public int hashCode() {
return first().hashCode() ^ second().hashCode();
}
}