From eedaa8b3f43b7be57b05a07b8c7d867c0f9dec7d Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sat, 11 Jan 2014 16:30:37 -1000 Subject: [PATCH] Simplify message handling by returning early and throwing out maps --- .../controllers/FederationController.java | 25 ++-- .../controllers/MessageController.java | 90 +++++++----- .../controllers/MissingDevicesException.java | 11 ++ .../entities/MessageResponse.java | 27 +++- .../storage/AccountsManager.java | 33 ++--- .../controllers/FederatedControllerTest.java | 133 ------------------ .../controllers/FederatedControllerTest.java | 127 +++++++++++++++++ 7 files changed, 244 insertions(+), 202 deletions(-) create mode 100644 src/main/java/org/whispersystems/textsecuregcm/controllers/MissingDevicesException.java delete mode 100644 src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java create mode 100644 src/test/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java index 683d53a50..5c41af40b 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java @@ -60,6 +60,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import static com.google.common.base.Preconditions.checkState; + @Path("/v1/federation") public class FederationController { @@ -125,17 +127,24 @@ public class FederationController { deviceIds.add(message.getDestinationDeviceId()); } - Pair, List> accountsForDevices = accounts.getAccountsForDevices(localDestinations); + List localAccounts = null; + try { + localAccounts = accounts.getAccountsForDevices(localDestinations); + } catch (MissingDevicesException e) { + return new MessageResponse(e.missingNumbers); + } - Map localAccounts = accountsForDevices.first(); - List numbersMissingDevices = accountsForDevices.second(); List success = new LinkedList<>(); - List failure = new LinkedList<>(numbersMissingDevices); + List failure = new LinkedList<>(); for (RelayMessage message : messages) { - Account destinationAccount = localAccounts.get(message.getDestination()); - if (destinationAccount == null) - continue; + Account destinationAccount = null; + for (Account account : localAccounts) + if (account.getNumber().equals(message.getDestination())) + destinationAccount= account; + + checkState(destinationAccount != null); + Device device = destinationAccount.getDevice(message.getDestinationDeviceId()); OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal()) .toBuilder() @@ -150,7 +159,7 @@ public class FederationController { } } - return new MessageResponse(success, failure, numbersMissingDevices); + return new MessageResponse(success, failure); } catch (InvalidProtocolBufferException ipe) { logger.warn("ProtoBuf", ipe); throw new WebApplicationException(Response.status(400).build()); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 8adfd5fd9..ab1dbbe0d 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -45,7 +45,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Base64; -import org.whispersystems.textsecuregcm.util.IterablePair; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; @@ -54,12 +53,10 @@ import javax.servlet.AsyncContext; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.ws.rs.Path; import java.io.BufferedReader; import java.io.IOException; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -117,16 +114,9 @@ public class MessageController extends HttpServlet { try { Device sender = authenticate(req); - IncomingMessageList messages = parseIncomingMessages(req); - rateLimiters.getMessagesLimiter().validate(sender.getNumber()); - List numbersMissingDevices = new LinkedList<>(); - - List> outgoingMessages = - getOutgoingMessageSignals(sender.getNumber(), messages.getMessages(), numbersMissingDevices); - - handleAsyncDelivery(timerContext, req.startAsync(), outgoingMessages, numbersMissingDevices); + handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req)); } catch (AuthenticationException e) { failureMeter.mark(); timerContext.stop(); @@ -149,19 +139,32 @@ public class MessageController extends HttpServlet { private void handleAsyncDelivery(final TimerContext timerContext, final AsyncContext context, - final List> listPair, - final List numbersMissingDevices) + final Device sender, + final IncomingMessageList messages) { executor.submit(new Runnable() { @Override public void run() { List success = new LinkedList<>(); - List failure = new LinkedList<>(numbersMissingDevices); + List failure = new LinkedList<>(); HttpServletResponse response = (HttpServletResponse) context.getResponse(); try { + List> outgoingMessages; + try { + outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages()); + } catch (MissingDevicesException e) { + byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers)); + response.setContentLength(responseData.length); + response.getOutputStream().write(responseData); + context.complete(); + failureMeter.mark(); + timerContext.stop(); + return; + } + Map>> relayMessages = new HashMap<>(); - for (Pair messagePair : listPair) { + for (Pair messagePair : outgoingMessages) { String relay = messagePair.first().relay; if (Util.isEmpty(relay)) { @@ -199,8 +202,6 @@ public class MessageController extends HttpServlet { success.add(string); for (String string : relayResponse.getFailure()) failure.add(string); - for (String string : relayResponse.getNumbersMissingDevices()) - numbersMissingDevices.add(string); } catch (NoSuchPeerException e) { logger.info("No such peer", e); for (Pair messagePair : messagesForRelay.getValue()) @@ -208,7 +209,7 @@ public class MessageController extends HttpServlet { } } - byte[] responseData = serializeResponse(new MessageResponse(success, failure, numbersMissingDevices)); + byte[] responseData = serializeResponse(new MessageResponse(success, failure)); response.setContentLength(responseData.length); response.getOutputStream().write(responseData); context.complete(); @@ -232,30 +233,16 @@ public class MessageController extends HttpServlet { @Nullable private List> getOutgoingMessageSignals(String sourceNumber, - List incomingMessages, - List numbersMissingDevices) + List incomingMessages) + throws MissingDevicesException { List> outgoingMessages = new LinkedList<>(); - Map> localDestinations = new HashMap<>(); + + List localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages)); + Set destinationNumbers = new HashSet<>(); - for (IncomingMessage incoming : incomingMessages) { + for (IncomingMessage incoming : incomingMessages) destinationNumbers.add(incoming.getDestination()); - if (!Util.isEmpty(incoming.getRelay())) - continue; - - Set deviceIds = localDestinations.get(incoming.getDestination()); - if (deviceIds == null) { - deviceIds = new HashSet<>(); - localDestinations.put(incoming.getDestination(), deviceIds); - } - deviceIds.add(incoming.getDestinationDeviceId()); - } - - Pair, List> accountsForDevices = accountsManager.getAccountsForDevices(localDestinations); - - Map localAccounts = accountsForDevices.first(); - for (String number : accountsForDevices.second()) - numbersMissingDevices.add(number); for (IncomingMessage incoming : incomingMessages) { OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder(); @@ -281,7 +268,14 @@ public class MessageController extends HttpServlet { if (!Util.isEmpty(incoming.getRelay())) device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId()); else { - Account destination = localAccounts.get(incoming.getDestination()); + Account destination = null; + for (Account account : localAccounts) { + if (account.getNumber().equals(incoming.getDestination())) { + destination = account; + break; + } + } + if (destination != null) device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId())); } @@ -293,6 +287,24 @@ public class MessageController extends HttpServlet { return outgoingMessages; } + // We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that + // we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages) + private Map> getLocalDestinations(List incomingMessages) { + Map> localDestinations = new HashMap<>(); + for (IncomingMessage incoming : incomingMessages) { + if (!Util.isEmpty(incoming.getRelay())) + continue; + + Set deviceIds = localDestinations.get(incoming.getDestination()); + if (deviceIds == null) { + deviceIds = new HashSet<>(); + localDestinations.put(incoming.getDestination(), deviceIds); + } + deviceIds.add(incoming.getDestinationDeviceId()); + } + return localDestinations; + } + private byte[] getMessageBody(IncomingMessage message) { try { return Base64.decode(message.getBody()); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/MissingDevicesException.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/MissingDevicesException.java new file mode 100644 index 000000000..f5fc3c93b --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/MissingDevicesException.java @@ -0,0 +1,11 @@ +package org.whispersystems.textsecuregcm.controllers; + +import java.util.List; +import java.util.Set; + +public class MissingDevicesException extends Exception { + public Set missingNumbers; + public MissingDevicesException(Set missingNumbers) { + this.missingNumbers = missingNumbers; + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java index fc8870a42..83e92d2ba 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/MessageResponse.java @@ -16,16 +16,25 @@ */ package org.whispersystems.textsecuregcm.entities; +import java.util.HashSet; +import java.util.LinkedList; import java.util.List; +import java.util.Set; public class MessageResponse { private List success; private List failure; - private List missingDeviceIds; + private Set missingDeviceIds; - public MessageResponse(List success, List failure, List missingDeviceIds) { + public MessageResponse(List success, List failure) { this.success = success; this.failure = failure; + this.missingDeviceIds = new HashSet<>(); + } + + public MessageResponse(Set missingDeviceIds) { + this.success = new LinkedList<>(); + this.failure = new LinkedList<>(missingDeviceIds); this.missingDeviceIds = missingDeviceIds; } @@ -35,11 +44,23 @@ public class MessageResponse { return success; } + public void setSuccess(List success) { + this.success = success; + } + public List getFailure() { return failure; } - public List getNumbersMissingDevices() { + public void setFailure(List failure) { + this.failure = failure; + } + + public Set getNumbersMissingDevices() { return missingDeviceIds; } + + public void setNumbersMissingDevices(Set numbersMissingDevices) { + this.missingDeviceIds = numbersMissingDevices; + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index e0fa2e484..3a154981d 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -19,6 +19,7 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.base.Optional; import net.spy.memcached.MemcachedClient; +import org.whispersystems.textsecuregcm.controllers.MissingDevicesException; import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; @@ -119,36 +120,30 @@ public class AccountsManager { return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices)); } - private Map getAllAccounts(Set numbers) { + private List getAllAccounts(Set numbers) { //TODO: ONE QUERY - Map result = new HashMap<>(); + List accounts = new LinkedList<>(); for (String number : numbers) { Optional account = getAccount(number); if (account.isPresent()) - result.put(number, account.get()); + accounts.add(account.get()); } - return result; + return accounts; } - public Pair, List> getAccountsForDevices(Map> destinations) { - List numbersMissingDevices = new LinkedList<>(); - Map localAccounts = getAllAccounts(destinations.keySet()); + public List getAccountsForDevices(Map> destinations) throws MissingDevicesException { + Set numbersMissingDevices = new HashSet<>(destinations.keySet()); + List localAccounts = getAllAccounts(destinations.keySet()); - for (String number : destinations.keySet()) { - if (localAccounts.get(number) == null) - numbersMissingDevices.add(number); + for (Account account : localAccounts){ + if (account.hasAllDeviceIds(destinations.get(account.getNumber()))) + numbersMissingDevices.remove(account.getNumber()); } - Iterator localAccountIterator = localAccounts.values().iterator(); - while (localAccountIterator.hasNext()) { - Account account = localAccountIterator.next(); - if (!account.hasAllDeviceIds(destinations.get(account.getNumber()))) { - numbersMissingDevices.add(account.getNumber()); - localAccountIterator.remove(); - } - } + if (!numbersMissingDevices.isEmpty()) + throw new MissingDevicesException(numbersMissingDevices); - return new Pair<>(localAccounts, numbersMissingDevices); + return localAccounts; } private void updateDirectory(Device device) { diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java deleted file mode 100644 index 59f026a81..000000000 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java +++ /dev/null @@ -1,133 +0,0 @@ -package org.whispersystems.textsecuregcm.tests.controllers; - -import com.google.common.base.Optional; -import com.sun.jersey.api.client.ClientResponse; -import com.sun.jersey.api.client.GenericType; -import com.yammer.dropwizard.testing.ResourceTest; -import org.junit.Test; -import org.whispersystems.textsecuregcm.controllers.FederationController; -import org.whispersystems.textsecuregcm.controllers.KeysController; -import org.whispersystems.textsecuregcm.entities.MessageResponse; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.RelayMessage; -import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; -import org.whispersystems.textsecuregcm.limits.RateLimiter; -import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.push.PushSender; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.Keys; -import org.whispersystems.textsecuregcm.tests.util.AuthHelper; -import org.whispersystems.textsecuregcm.util.Pair; -import org.whispersystems.textsecuregcm.util.UrlSigner; - -import javax.ws.rs.core.MediaType; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import static org.fest.assertions.api.Assertions.assertThat; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -public class FederatedControllerTest extends ResourceTest { - - private final String EXISTS_NUMBER = "+14152222222"; - private final String NOT_EXISTS_NUMBER = "+14152222220"; - - private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, AuthHelper.DEFAULT_DEVICE_ID, 1234, "test1", "test2", false); - private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false); - - private final Keys keys = mock(Keys.class); - private final PushSender pushSender = mock(PushSender.class); - - Device[] fakeDevice; - Account existsAccount; - - @Override - protected void setUpResources() { - addProvider(AuthHelper.getAuthenticator()); - - RateLimiters rateLimiters = mock(RateLimiters.class); - RateLimiter rateLimiter = mock(RateLimiter.class ); - AccountsManager accounts = mock(AccountsManager.class); - - fakeDevice = new Device[2]; - fakeDevice[0] = mock(Device.class); - fakeDevice[1] = mock(Device.class); - existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])); - - Account account = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])); - - when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - - when(keys.get(eq(EXISTS_NUMBER), isA(Account.class))).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2))); - when(keys.get(eq(NOT_EXISTS_NUMBER), isA(Account.class))).thenReturn(null); - - when(fakeDevice[0].getDeviceId()).thenReturn(AuthHelper.DEFAULT_DEVICE_ID); - when(fakeDevice[1].getDeviceId()).thenReturn((long) 2); - - Map> VALID_SET = new HashMap<>(); - VALID_SET.put(EXISTS_NUMBER, new HashSet(Arrays.asList((long)1))); - Map VALID_ACCOUNTS = new HashMap<>(); - VALID_ACCOUNTS.put(EXISTS_NUMBER, account); - - Map> INVALID_SET = new HashMap<>(); - INVALID_SET.put(NOT_EXISTS_NUMBER, new HashSet(Arrays.asList((long) 1))); - - when(accounts.getAccountsForDevices(eq(VALID_SET))) - .thenReturn(new Pair, List>(VALID_ACCOUNTS, new LinkedList())); - when(accounts.getAccountsForDevices(eq(INVALID_SET))) - .thenReturn(new Pair, List>( - new HashMap(), Arrays.asList(NOT_EXISTS_NUMBER))); - - addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class))); - } - - @Test - public void validRequestsTest() throws Exception { - MessageResponse result = client().resource("/v1/federation/message") - .entity(new RelayMessage(EXISTS_NUMBER, 1, new byte[] {(byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF})) - .type(MediaType.APPLICATION_JSON) - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN)) - .post(MessageResponse.class); - - } - - @Test - public void invalidRequestTest() throws Exception { - ClientResponse response = client().resource("/v1/federation/message") - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) - .post(ClientResponse.class); - - assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404); - verifyNoMoreInteractions(keys); - } - - @Test - public void unauthorizedRequestTest() throws Exception { - ClientResponse response = - client().resource("/v1/federation/message") - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD)) - .post(ClientResponse.class); - - assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(401); - - response = - client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER)) - .post(ClientResponse.class); - - assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(401); - } - -} \ No newline at end of file diff --git a/src/test/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java b/src/test/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java new file mode 100644 index 000000000..517168342 --- /dev/null +++ b/src/test/org/whispersystems/textsecuregcm/tests/controllers/FederatedControllerTest.java @@ -0,0 +1,127 @@ +package org.whispersystems.textsecuregcm.tests.controllers; + +import com.google.common.base.Optional; +import com.sun.jersey.api.client.ClientResponse; +import com.sun.jersey.api.client.GenericType; +import com.yammer.dropwizard.testing.ResourceTest; +import org.junit.Test; +import org.whispersystems.textsecuregcm.controllers.FederationController; +import org.whispersystems.textsecuregcm.controllers.KeysController; +import org.whispersystems.textsecuregcm.controllers.MissingDevicesException; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.MessageResponse; +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.RelayMessage; +import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.util.UrlSigner; + +import javax.ws.rs.core.MediaType; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.fest.assertions.api.Assertions.assertThat; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class FederatedControllerTest extends ResourceTest { + + private final String EXISTS_NUMBER = "+14152222222"; + private final String EXISTS_NUMBER_2 = "+14154444444"; + private final String NOT_EXISTS_NUMBER = "+14152222220"; + + private final Keys keys = mock(Keys.class); + private final PushSender pushSender = mock(PushSender.class); + + Device[] fakeDevice; + Account existsAccount; + + @Override + protected void setUpResources() throws MissingDevicesException { + addProvider(AuthHelper.getAuthenticator()); + + RateLimiters rateLimiters = mock(RateLimiters.class); + RateLimiter rateLimiter = mock(RateLimiter.class ); + AccountsManager accounts = mock(AccountsManager.class); + + fakeDevice = new Device[2]; + fakeDevice[0] = new Device(42, EXISTS_NUMBER, 1, "", "", "", null, null, true, false); + fakeDevice[1] = new Device(43, EXISTS_NUMBER, 2, "", "", "", null, null, false, true); + existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])); + + when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); + + Map> validOneElementSet = new HashMap<>(); + validOneElementSet.put(EXISTS_NUMBER_2, new HashSet<>(Arrays.asList((long) 1))); + List validOneAccount = Arrays.asList(new Account(EXISTS_NUMBER_2, true, + Arrays.asList(new Device(44, EXISTS_NUMBER_2, 1, "", "", "", null, null, true, false)))); + + Map> validTwoElementsSet = new HashMap<>(); + validTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1, (long) 2))); + List validTwoAccount = Arrays.asList(new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]))); + + Map> invalidTwoElementsSet = new HashMap<>(); + invalidTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1))); + + when(accounts.getAccountsForDevices(eq(validOneElementSet))).thenReturn(validOneAccount); + when(accounts.getAccountsForDevices(eq(validTwoElementsSet))).thenReturn(validTwoAccount); + when(accounts.getAccountsForDevices(eq(invalidTwoElementsSet))).thenThrow(new MissingDevicesException(new HashSet<>(Arrays.asList(EXISTS_NUMBER)))); + + addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class))); + } + + @Test + public void validRequestsTest() throws Exception { + MessageResponse result = client().resource("/v1/federation/message") + .entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER_2, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()))) + .type(MediaType.APPLICATION_JSON) + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN)) + .put(MessageResponse.class); + + assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER_2)); + assertThat(result.getFailure()).isEmpty(); + assertThat(result.getNumbersMissingDevices()).isEmpty(); + + result = client().resource("/v1/federation/message") + .entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()), + new RelayMessage(EXISTS_NUMBER, 2, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()))) + .type(MediaType.APPLICATION_JSON) + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN)) + .put(MessageResponse.class); + + assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER, EXISTS_NUMBER + "." + 2)); + assertThat(result.getFailure()).isEmpty(); + assertThat(result.getNumbersMissingDevices()).isEmpty(); + } + + @Test + public void invalidRequestTest() throws Exception { + MessageResponse result = client().resource("/v1/federation/message") + .entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()))) + .type(MediaType.APPLICATION_JSON) + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN)) + .put(MessageResponse.class); + + assertThat(result.getSuccess()).isEmpty(); + assertThat(result.getFailure()).isEqualTo(Arrays.asList(EXISTS_NUMBER)); + assertThat(result.getNumbersMissingDevices()).isEqualTo(new HashSet<>(Arrays.asList(EXISTS_NUMBER))); + } +} \ No newline at end of file