Simplify message handling by returning early and throwing out maps
This commit is contained in:
parent
7af3c51cc4
commit
eedaa8b3f4
|
@ -60,6 +60,8 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static com.google.common.base.Preconditions.checkState;
|
||||||
|
|
||||||
@Path("/v1/federation")
|
@Path("/v1/federation")
|
||||||
public class FederationController {
|
public class FederationController {
|
||||||
|
|
||||||
|
@ -125,17 +127,24 @@ public class FederationController {
|
||||||
deviceIds.add(message.getDestinationDeviceId());
|
deviceIds.add(message.getDestinationDeviceId());
|
||||||
}
|
}
|
||||||
|
|
||||||
Pair<Map<String, Account>, List<String>> accountsForDevices = accounts.getAccountsForDevices(localDestinations);
|
List<Account> localAccounts = null;
|
||||||
|
try {
|
||||||
|
localAccounts = accounts.getAccountsForDevices(localDestinations);
|
||||||
|
} catch (MissingDevicesException e) {
|
||||||
|
return new MessageResponse(e.missingNumbers);
|
||||||
|
}
|
||||||
|
|
||||||
Map<String, Account> localAccounts = accountsForDevices.first();
|
|
||||||
List<String> numbersMissingDevices = accountsForDevices.second();
|
|
||||||
List<String> success = new LinkedList<>();
|
List<String> success = new LinkedList<>();
|
||||||
List<String> failure = new LinkedList<>(numbersMissingDevices);
|
List<String> failure = new LinkedList<>();
|
||||||
|
|
||||||
for (RelayMessage message : messages) {
|
for (RelayMessage message : messages) {
|
||||||
Account destinationAccount = localAccounts.get(message.getDestination());
|
Account destinationAccount = null;
|
||||||
if (destinationAccount == null)
|
for (Account account : localAccounts)
|
||||||
continue;
|
if (account.getNumber().equals(message.getDestination()))
|
||||||
|
destinationAccount= account;
|
||||||
|
|
||||||
|
checkState(destinationAccount != null);
|
||||||
|
|
||||||
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
|
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
|
||||||
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
|
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
|
||||||
.toBuilder()
|
.toBuilder()
|
||||||
|
@ -150,7 +159,7 @@ public class FederationController {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new MessageResponse(success, failure, numbersMissingDevices);
|
return new MessageResponse(success, failure);
|
||||||
} catch (InvalidProtocolBufferException ipe) {
|
} catch (InvalidProtocolBufferException ipe) {
|
||||||
logger.warn("ProtoBuf", ipe);
|
logger.warn("ProtoBuf", ipe);
|
||||||
throw new WebApplicationException(Response.status(400).build());
|
throw new WebApplicationException(Response.status(400).build());
|
||||||
|
|
|
@ -45,7 +45,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
import org.whispersystems.textsecuregcm.util.Base64;
|
import org.whispersystems.textsecuregcm.util.Base64;
|
||||||
import org.whispersystems.textsecuregcm.util.IterablePair;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
|
@ -54,12 +53,10 @@ import javax.servlet.AsyncContext;
|
||||||
import javax.servlet.http.HttpServlet;
|
import javax.servlet.http.HttpServlet;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import javax.ws.rs.Path;
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -117,16 +114,9 @@ public class MessageController extends HttpServlet {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Device sender = authenticate(req);
|
Device sender = authenticate(req);
|
||||||
IncomingMessageList messages = parseIncomingMessages(req);
|
|
||||||
|
|
||||||
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
|
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
|
||||||
|
|
||||||
List<String> numbersMissingDevices = new LinkedList<>();
|
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
|
||||||
|
|
||||||
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages =
|
|
||||||
getOutgoingMessageSignals(sender.getNumber(), messages.getMessages(), numbersMissingDevices);
|
|
||||||
|
|
||||||
handleAsyncDelivery(timerContext, req.startAsync(), outgoingMessages, numbersMissingDevices);
|
|
||||||
} catch (AuthenticationException e) {
|
} catch (AuthenticationException e) {
|
||||||
failureMeter.mark();
|
failureMeter.mark();
|
||||||
timerContext.stop();
|
timerContext.stop();
|
||||||
|
@ -149,19 +139,32 @@ public class MessageController extends HttpServlet {
|
||||||
|
|
||||||
private void handleAsyncDelivery(final TimerContext timerContext,
|
private void handleAsyncDelivery(final TimerContext timerContext,
|
||||||
final AsyncContext context,
|
final AsyncContext context,
|
||||||
final List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> listPair,
|
final Device sender,
|
||||||
final List<String> numbersMissingDevices)
|
final IncomingMessageList messages)
|
||||||
{
|
{
|
||||||
executor.submit(new Runnable() {
|
executor.submit(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
List<String> success = new LinkedList<>();
|
List<String> success = new LinkedList<>();
|
||||||
List<String> failure = new LinkedList<>(numbersMissingDevices);
|
List<String> failure = new LinkedList<>();
|
||||||
HttpServletResponse response = (HttpServletResponse) context.getResponse();
|
HttpServletResponse response = (HttpServletResponse) context.getResponse();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
|
||||||
|
try {
|
||||||
|
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
|
||||||
|
} catch (MissingDevicesException e) {
|
||||||
|
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
|
||||||
|
response.setContentLength(responseData.length);
|
||||||
|
response.getOutputStream().write(responseData);
|
||||||
|
context.complete();
|
||||||
|
failureMeter.mark();
|
||||||
|
timerContext.stop();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
|
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
|
||||||
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : listPair) {
|
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
|
||||||
String relay = messagePair.first().relay;
|
String relay = messagePair.first().relay;
|
||||||
|
|
||||||
if (Util.isEmpty(relay)) {
|
if (Util.isEmpty(relay)) {
|
||||||
|
@ -199,8 +202,6 @@ public class MessageController extends HttpServlet {
|
||||||
success.add(string);
|
success.add(string);
|
||||||
for (String string : relayResponse.getFailure())
|
for (String string : relayResponse.getFailure())
|
||||||
failure.add(string);
|
failure.add(string);
|
||||||
for (String string : relayResponse.getNumbersMissingDevices())
|
|
||||||
numbersMissingDevices.add(string);
|
|
||||||
} catch (NoSuchPeerException e) {
|
} catch (NoSuchPeerException e) {
|
||||||
logger.info("No such peer", e);
|
logger.info("No such peer", e);
|
||||||
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
|
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
|
||||||
|
@ -208,7 +209,7 @@ public class MessageController extends HttpServlet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices));
|
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
|
||||||
response.setContentLength(responseData.length);
|
response.setContentLength(responseData.length);
|
||||||
response.getOutputStream().write(responseData);
|
response.getOutputStream().write(responseData);
|
||||||
context.complete();
|
context.complete();
|
||||||
|
@ -232,30 +233,16 @@ public class MessageController extends HttpServlet {
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
|
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
|
||||||
List<IncomingMessage> incomingMessages,
|
List<IncomingMessage> incomingMessages)
|
||||||
List<String> numbersMissingDevices)
|
throws MissingDevicesException
|
||||||
{
|
{
|
||||||
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
|
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
|
||||||
Map<String, Set<Long>> localDestinations = new HashMap<>();
|
|
||||||
|
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
|
||||||
|
|
||||||
Set<String> destinationNumbers = new HashSet<>();
|
Set<String> destinationNumbers = new HashSet<>();
|
||||||
for (IncomingMessage incoming : incomingMessages) {
|
for (IncomingMessage incoming : incomingMessages)
|
||||||
destinationNumbers.add(incoming.getDestination());
|
destinationNumbers.add(incoming.getDestination());
|
||||||
if (!Util.isEmpty(incoming.getRelay()))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
|
|
||||||
if (deviceIds == null) {
|
|
||||||
deviceIds = new HashSet<>();
|
|
||||||
localDestinations.put(incoming.getDestination(), deviceIds);
|
|
||||||
}
|
|
||||||
deviceIds.add(incoming.getDestinationDeviceId());
|
|
||||||
}
|
|
||||||
|
|
||||||
Pair<Map<String, Account>, List<String>> accountsForDevices = accountsManager.getAccountsForDevices(localDestinations);
|
|
||||||
|
|
||||||
Map<String, Account> localAccounts = accountsForDevices.first();
|
|
||||||
for (String number : accountsForDevices.second())
|
|
||||||
numbersMissingDevices.add(number);
|
|
||||||
|
|
||||||
for (IncomingMessage incoming : incomingMessages) {
|
for (IncomingMessage incoming : incomingMessages) {
|
||||||
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
|
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
|
||||||
|
@ -281,7 +268,14 @@ public class MessageController extends HttpServlet {
|
||||||
if (!Util.isEmpty(incoming.getRelay()))
|
if (!Util.isEmpty(incoming.getRelay()))
|
||||||
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
|
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
|
||||||
else {
|
else {
|
||||||
Account destination = localAccounts.get(incoming.getDestination());
|
Account destination = null;
|
||||||
|
for (Account account : localAccounts) {
|
||||||
|
if (account.getNumber().equals(incoming.getDestination())) {
|
||||||
|
destination = account;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (destination != null)
|
if (destination != null)
|
||||||
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
|
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
|
||||||
}
|
}
|
||||||
|
@ -293,6 +287,24 @@ public class MessageController extends HttpServlet {
|
||||||
return outgoingMessages;
|
return outgoingMessages;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
|
||||||
|
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
|
||||||
|
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
|
||||||
|
Map<String, Set<Long>> localDestinations = new HashMap<>();
|
||||||
|
for (IncomingMessage incoming : incomingMessages) {
|
||||||
|
if (!Util.isEmpty(incoming.getRelay()))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
|
||||||
|
if (deviceIds == null) {
|
||||||
|
deviceIds = new HashSet<>();
|
||||||
|
localDestinations.put(incoming.getDestination(), deviceIds);
|
||||||
|
}
|
||||||
|
deviceIds.add(incoming.getDestinationDeviceId());
|
||||||
|
}
|
||||||
|
return localDestinations;
|
||||||
|
}
|
||||||
|
|
||||||
private byte[] getMessageBody(IncomingMessage message) {
|
private byte[] getMessageBody(IncomingMessage message) {
|
||||||
try {
|
try {
|
||||||
return Base64.decode(message.getBody());
|
return Base64.decode(message.getBody());
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
package org.whispersystems.textsecuregcm.controllers;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public class MissingDevicesException extends Exception {
|
||||||
|
public Set<String> missingNumbers;
|
||||||
|
public MissingDevicesException(Set<String> missingNumbers) {
|
||||||
|
this.missingNumbers = missingNumbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,16 +16,25 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
public class MessageResponse {
|
public class MessageResponse {
|
||||||
private List<String> success;
|
private List<String> success;
|
||||||
private List<String> failure;
|
private List<String> failure;
|
||||||
private List<String> missingDeviceIds;
|
private Set<String> missingDeviceIds;
|
||||||
|
|
||||||
public MessageResponse(List<String> success, List<String> failure, List<String> missingDeviceIds) {
|
public MessageResponse(List<String> success, List<String> failure) {
|
||||||
this.success = success;
|
this.success = success;
|
||||||
this.failure = failure;
|
this.failure = failure;
|
||||||
|
this.missingDeviceIds = new HashSet<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
public MessageResponse(Set<String> missingDeviceIds) {
|
||||||
|
this.success = new LinkedList<>();
|
||||||
|
this.failure = new LinkedList<>(missingDeviceIds);
|
||||||
this.missingDeviceIds = missingDeviceIds;
|
this.missingDeviceIds = missingDeviceIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,11 +44,23 @@ public class MessageResponse {
|
||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setSuccess(List<String> success) {
|
||||||
|
this.success = success;
|
||||||
|
}
|
||||||
|
|
||||||
public List<String> getFailure() {
|
public List<String> getFailure() {
|
||||||
return failure;
|
return failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<String> getNumbersMissingDevices() {
|
public void setFailure(List<String> failure) {
|
||||||
|
this.failure = failure;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<String> getNumbersMissingDevices() {
|
||||||
return missingDeviceIds;
|
return missingDeviceIds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setNumbersMissingDevices(Set<String> numbersMissingDevices) {
|
||||||
|
this.missingDeviceIds = numbersMissingDevices;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import com.google.common.base.Optional;
|
import com.google.common.base.Optional;
|
||||||
import net.spy.memcached.MemcachedClient;
|
import net.spy.memcached.MemcachedClient;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.MissingDevicesException;
|
||||||
import org.whispersystems.textsecuregcm.entities.ClientContact;
|
import org.whispersystems.textsecuregcm.entities.ClientContact;
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
@ -119,36 +120,30 @@ public class AccountsManager {
|
||||||
return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices));
|
return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, Account> getAllAccounts(Set<String> numbers) {
|
private List<Account> getAllAccounts(Set<String> numbers) {
|
||||||
//TODO: ONE QUERY
|
//TODO: ONE QUERY
|
||||||
Map<String, Account> result = new HashMap<>();
|
List<Account> accounts = new LinkedList<>();
|
||||||
for (String number : numbers) {
|
for (String number : numbers) {
|
||||||
Optional<Account> account = getAccount(number);
|
Optional<Account> account = getAccount(number);
|
||||||
if (account.isPresent())
|
if (account.isPresent())
|
||||||
result.put(number, account.get());
|
accounts.add(account.get());
|
||||||
}
|
}
|
||||||
return result;
|
return accounts;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Pair<Map<String, Account>, List<String>> getAccountsForDevices(Map<String, Set<Long>> destinations) {
|
public List<Account> getAccountsForDevices(Map<String, Set<Long>> destinations) throws MissingDevicesException {
|
||||||
List<String> numbersMissingDevices = new LinkedList<>();
|
Set<String> numbersMissingDevices = new HashSet<>(destinations.keySet());
|
||||||
Map<String, Account> localAccounts = getAllAccounts(destinations.keySet());
|
List<Account> localAccounts = getAllAccounts(destinations.keySet());
|
||||||
|
|
||||||
for (String number : destinations.keySet()) {
|
for (Account account : localAccounts){
|
||||||
if (localAccounts.get(number) == null)
|
if (account.hasAllDeviceIds(destinations.get(account.getNumber())))
|
||||||
numbersMissingDevices.add(number);
|
numbersMissingDevices.remove(account.getNumber());
|
||||||
}
|
}
|
||||||
|
|
||||||
Iterator<Account> localAccountIterator = localAccounts.values().iterator();
|
if (!numbersMissingDevices.isEmpty())
|
||||||
while (localAccountIterator.hasNext()) {
|
throw new MissingDevicesException(numbersMissingDevices);
|
||||||
Account account = localAccountIterator.next();
|
|
||||||
if (!account.hasAllDeviceIds(destinations.get(account.getNumber()))) {
|
|
||||||
numbersMissingDevices.add(account.getNumber());
|
|
||||||
localAccountIterator.remove();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Pair<>(localAccounts, numbersMissingDevices);
|
return localAccounts;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateDirectory(Device device) {
|
private void updateDirectory(Device device) {
|
||||||
|
|
|
@ -1,133 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.tests.controllers;
|
|
||||||
|
|
||||||
import com.google.common.base.Optional;
|
|
||||||
import com.sun.jersey.api.client.ClientResponse;
|
|
||||||
import com.sun.jersey.api.client.GenericType;
|
|
||||||
import com.yammer.dropwizard.testing.ResourceTest;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.FederationController;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.KeysController;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageResponse;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.PreKey;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.RelayMessage;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
|
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
|
||||||
import org.whispersystems.textsecuregcm.push.PushSender;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Device;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Keys;
|
|
||||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
|
||||||
import org.whispersystems.textsecuregcm.util.UrlSigner;
|
|
||||||
|
|
||||||
import javax.ws.rs.core.MediaType;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.LinkedList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import static org.fest.assertions.api.Assertions.assertThat;
|
|
||||||
import static org.mockito.Mockito.eq;
|
|
||||||
import static org.mockito.Mockito.isA;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.times;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class FederatedControllerTest extends ResourceTest {
|
|
||||||
|
|
||||||
private final String EXISTS_NUMBER = "+14152222222";
|
|
||||||
private final String NOT_EXISTS_NUMBER = "+14152222220";
|
|
||||||
|
|
||||||
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, AuthHelper.DEFAULT_DEVICE_ID, 1234, "test1", "test2", false);
|
|
||||||
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false);
|
|
||||||
|
|
||||||
private final Keys keys = mock(Keys.class);
|
|
||||||
private final PushSender pushSender = mock(PushSender.class);
|
|
||||||
|
|
||||||
Device[] fakeDevice;
|
|
||||||
Account existsAccount;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void setUpResources() {
|
|
||||||
addProvider(AuthHelper.getAuthenticator());
|
|
||||||
|
|
||||||
RateLimiters rateLimiters = mock(RateLimiters.class);
|
|
||||||
RateLimiter rateLimiter = mock(RateLimiter.class );
|
|
||||||
AccountsManager accounts = mock(AccountsManager.class);
|
|
||||||
|
|
||||||
fakeDevice = new Device[2];
|
|
||||||
fakeDevice[0] = mock(Device.class);
|
|
||||||
fakeDevice[1] = mock(Device.class);
|
|
||||||
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
|
|
||||||
|
|
||||||
Account account = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
|
|
||||||
|
|
||||||
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
|
|
||||||
|
|
||||||
when(keys.get(eq(EXISTS_NUMBER), isA(Account.class))).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2)));
|
|
||||||
when(keys.get(eq(NOT_EXISTS_NUMBER), isA(Account.class))).thenReturn(null);
|
|
||||||
|
|
||||||
when(fakeDevice[0].getDeviceId()).thenReturn(AuthHelper.DEFAULT_DEVICE_ID);
|
|
||||||
when(fakeDevice[1].getDeviceId()).thenReturn((long) 2);
|
|
||||||
|
|
||||||
Map<String, Set<Long>> VALID_SET = new HashMap<>();
|
|
||||||
VALID_SET.put(EXISTS_NUMBER, new HashSet<Long>(Arrays.asList((long)1)));
|
|
||||||
Map<String, Account> VALID_ACCOUNTS = new HashMap<>();
|
|
||||||
VALID_ACCOUNTS.put(EXISTS_NUMBER, account);
|
|
||||||
|
|
||||||
Map<String, Set<Long>> INVALID_SET = new HashMap<>();
|
|
||||||
INVALID_SET.put(NOT_EXISTS_NUMBER, new HashSet<Long>(Arrays.asList((long) 1)));
|
|
||||||
|
|
||||||
when(accounts.getAccountsForDevices(eq(VALID_SET)))
|
|
||||||
.thenReturn(new Pair<Map<String, Account>, List<String>>(VALID_ACCOUNTS, new LinkedList<String>()));
|
|
||||||
when(accounts.getAccountsForDevices(eq(INVALID_SET)))
|
|
||||||
.thenReturn(new Pair<Map<String, Account>, List<String>>(
|
|
||||||
new HashMap<String, Account>(), Arrays.asList(NOT_EXISTS_NUMBER)));
|
|
||||||
|
|
||||||
addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class)));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void validRequestsTest() throws Exception {
|
|
||||||
MessageResponse result = client().resource("/v1/federation/message")
|
|
||||||
.entity(new RelayMessage(EXISTS_NUMBER, 1, new byte[] {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF}))
|
|
||||||
.type(MediaType.APPLICATION_JSON)
|
|
||||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
|
|
||||||
.post(MessageResponse.class);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void invalidRequestTest() throws Exception {
|
|
||||||
ClientResponse response = client().resource("/v1/federation/message")
|
|
||||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
|
|
||||||
.post(ClientResponse.class);
|
|
||||||
|
|
||||||
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404);
|
|
||||||
verifyNoMoreInteractions(keys);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void unauthorizedRequestTest() throws Exception {
|
|
||||||
ClientResponse response =
|
|
||||||
client().resource("/v1/federation/message")
|
|
||||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD))
|
|
||||||
.post(ClientResponse.class);
|
|
||||||
|
|
||||||
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(401);
|
|
||||||
|
|
||||||
response =
|
|
||||||
client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
|
|
||||||
.post(ClientResponse.class);
|
|
||||||
|
|
||||||
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(401);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
package org.whispersystems.textsecuregcm.tests.controllers;
|
||||||
|
|
||||||
|
import com.google.common.base.Optional;
|
||||||
|
import com.sun.jersey.api.client.ClientResponse;
|
||||||
|
import com.sun.jersey.api.client.GenericType;
|
||||||
|
import com.yammer.dropwizard.testing.ResourceTest;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.FederationController;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.KeysController;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.MissingDevicesException;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.MessageResponse;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.PreKey;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.RelayMessage;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
|
import org.whispersystems.textsecuregcm.push.PushSender;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Keys;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
|
import org.whispersystems.textsecuregcm.util.UrlSigner;
|
||||||
|
|
||||||
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.fest.assertions.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.Mockito.eq;
|
||||||
|
import static org.mockito.Mockito.isA;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class FederatedControllerTest extends ResourceTest {
|
||||||
|
|
||||||
|
private final String EXISTS_NUMBER = "+14152222222";
|
||||||
|
private final String EXISTS_NUMBER_2 = "+14154444444";
|
||||||
|
private final String NOT_EXISTS_NUMBER = "+14152222220";
|
||||||
|
|
||||||
|
private final Keys keys = mock(Keys.class);
|
||||||
|
private final PushSender pushSender = mock(PushSender.class);
|
||||||
|
|
||||||
|
Device[] fakeDevice;
|
||||||
|
Account existsAccount;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void setUpResources() throws MissingDevicesException {
|
||||||
|
addProvider(AuthHelper.getAuthenticator());
|
||||||
|
|
||||||
|
RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||||
|
RateLimiter rateLimiter = mock(RateLimiter.class );
|
||||||
|
AccountsManager accounts = mock(AccountsManager.class);
|
||||||
|
|
||||||
|
fakeDevice = new Device[2];
|
||||||
|
fakeDevice[0] = new Device(42, EXISTS_NUMBER, 1, "", "", "", null, null, true, false);
|
||||||
|
fakeDevice[1] = new Device(43, EXISTS_NUMBER, 2, "", "", "", null, null, false, true);
|
||||||
|
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
|
||||||
|
|
||||||
|
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
|
||||||
|
|
||||||
|
Map<String, Set<Long>> validOneElementSet = new HashMap<>();
|
||||||
|
validOneElementSet.put(EXISTS_NUMBER_2, new HashSet<>(Arrays.asList((long) 1)));
|
||||||
|
List<Account> validOneAccount = Arrays.asList(new Account(EXISTS_NUMBER_2, true,
|
||||||
|
Arrays.asList(new Device(44, EXISTS_NUMBER_2, 1, "", "", "", null, null, true, false))));
|
||||||
|
|
||||||
|
Map<String, Set<Long>> validTwoElementsSet = new HashMap<>();
|
||||||
|
validTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1, (long) 2)));
|
||||||
|
List<Account> validTwoAccount = Arrays.asList(new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])));
|
||||||
|
|
||||||
|
Map<String, Set<Long>> invalidTwoElementsSet = new HashMap<>();
|
||||||
|
invalidTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1)));
|
||||||
|
|
||||||
|
when(accounts.getAccountsForDevices(eq(validOneElementSet))).thenReturn(validOneAccount);
|
||||||
|
when(accounts.getAccountsForDevices(eq(validTwoElementsSet))).thenReturn(validTwoAccount);
|
||||||
|
when(accounts.getAccountsForDevices(eq(invalidTwoElementsSet))).thenThrow(new MissingDevicesException(new HashSet<>(Arrays.asList(EXISTS_NUMBER))));
|
||||||
|
|
||||||
|
addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void validRequestsTest() throws Exception {
|
||||||
|
MessageResponse result = client().resource("/v1/federation/message")
|
||||||
|
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER_2, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
|
||||||
|
.type(MediaType.APPLICATION_JSON)
|
||||||
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
|
||||||
|
.put(MessageResponse.class);
|
||||||
|
|
||||||
|
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER_2));
|
||||||
|
assertThat(result.getFailure()).isEmpty();
|
||||||
|
assertThat(result.getNumbersMissingDevices()).isEmpty();
|
||||||
|
|
||||||
|
result = client().resource("/v1/federation/message")
|
||||||
|
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()),
|
||||||
|
new RelayMessage(EXISTS_NUMBER, 2, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
|
||||||
|
.type(MediaType.APPLICATION_JSON)
|
||||||
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
|
||||||
|
.put(MessageResponse.class);
|
||||||
|
|
||||||
|
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER, EXISTS_NUMBER + "." + 2));
|
||||||
|
assertThat(result.getFailure()).isEmpty();
|
||||||
|
assertThat(result.getNumbersMissingDevices()).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void invalidRequestTest() throws Exception {
|
||||||
|
MessageResponse result = client().resource("/v1/federation/message")
|
||||||
|
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
|
||||||
|
.type(MediaType.APPLICATION_JSON)
|
||||||
|
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
|
||||||
|
.put(MessageResponse.class);
|
||||||
|
|
||||||
|
assertThat(result.getSuccess()).isEmpty();
|
||||||
|
assertThat(result.getFailure()).isEqualTo(Arrays.asList(EXISTS_NUMBER));
|
||||||
|
assertThat(result.getNumbersMissingDevices()).isEqualTo(new HashSet<>(Arrays.asList(EXISTS_NUMBER)));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue