From d3830a7fd4831af7959800efc08bd2f01d42c2e1 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sat, 11 Jan 2014 12:43:07 -1000 Subject: [PATCH] Split Account into Device and Account definitions. --- .../textsecuregcm/WhisperServerService.java | 2 +- .../controllers/AccountController.java | 4 +- .../controllers/DeviceController.java | 2 +- .../controllers/FederationController.java | 47 +++---- .../controllers/KeysController.java | 7 +- .../controllers/MessageController.java | 120 +++++++++++------- .../textsecuregcm/push/PushSender.java | 33 +---- .../textsecuregcm/storage/Account.java | 89 +++++++++++++ .../textsecuregcm/storage/Accounts.java | 17 +-- .../storage/AccountsManager.java | 63 +++++++-- .../textsecuregcm/storage/Device.java | 12 +- .../textsecuregcm/storage/Keys.java | 4 +- .../workers/DirectoryUpdater.java | 2 +- .../controllers/AccountControllerTest.java | 5 +- .../controllers/DeviceControllerTest.java | 4 +- .../tests/controllers/KeyControllerTest.java | 19 +-- 16 files changed, 290 insertions(+), 140 deletions(-) create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/Account.java diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index e13159f8f..84c55492f 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -135,7 +135,7 @@ public class WhisperServerService extends Service { environment.addResource(new FederationController(keys, accountsManager, pushSender, urlSigner)); environment.addServlet(new MessageController(rateLimiters, deviceAuthenticator, - pushSender, federatedClientManager), + pushSender, accountsManager, federatedClientManager), MessageController.PATH); environment.addHealthCheck(new RedisHealthCheck(redisClient)); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 65aa88330..3e2d9fa63 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -31,6 +31,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; @@ -53,6 +54,7 @@ import javax.ws.rs.core.Response; import java.io.IOException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; +import java.util.Arrays; @Path("/v1/accounts") public class AccountController { @@ -142,7 +144,7 @@ public class AccountController { device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setDeviceId(0); - accounts.createResetNumber(device); + accounts.create(new Account(number, accountAttributes.getSupportsSms(), device)); pendingAccounts.remove(number); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 831201bee..7eeddffbd 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -112,7 +112,7 @@ public class DeviceController { device.setSupportsSms(accountAttributes.getSupportsSms()); device.setFetchesMessages(accountAttributes.getFetchesMessages()); - accounts.createAccountOnExistingNumber(device); + accounts.provisionDevice(device); pendingDevices.remove(number); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java index d3f11a9af..683d53a50 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/FederationController.java @@ -17,6 +17,7 @@ package org.whispersystems.textsecuregcm.controllers; import com.amazonaws.HttpMethod; +import com.google.common.base.Optional; import com.google.protobuf.InvalidProtocolBufferException; import com.yammer.dropwizard.auth.Auth; import com.yammer.metrics.annotation.Timed; @@ -32,6 +33,7 @@ import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Keys; @@ -93,15 +95,15 @@ public class FederationController { @Path("/key/{number}") @Produces(MediaType.APPLICATION_JSON) public UnstructuredPreKeyList getKey(@Auth FederatedPeer peer, - @PathParam("number") String number) + @PathParam("number") String number) { - UnstructuredPreKeyList preKeys = keys.get(number, accounts.getAllByNumber(number)); - - if (preKeys == null) { + Optional account = accounts.getAccount(number); + UnstructuredPreKeyList keyList = null; + if (account.isPresent()) + keyList = keys.get(number, account.get()); + if (!account.isPresent() || keyList.getKeys().isEmpty()) throw new WebApplicationException(Response.status(404).build()); - } - - return preKeys; + return keyList; } @Timed @@ -113,37 +115,38 @@ public class FederationController { throws IOException { try { - Map>> destinations = new HashMap<>(); - + Map> localDestinations = new HashMap<>(); for (RelayMessage message : messages) { - Pair> deviceIds = destinations.get(message.getDestination()); + Set deviceIds = localDestinations.get(message.getDestination()); if (deviceIds == null) { - deviceIds = new Pair>(true, new HashSet()); - destinations.put(message.getDestination(), deviceIds); + deviceIds = new HashSet<>(); + localDestinations.put(message.getDestination(), deviceIds); } - deviceIds.second().add(message.getDestinationDeviceId()); + deviceIds.add(message.getDestinationDeviceId()); } - Map, Device> accountCache = new HashMap<>(); - List numbersMissingDevices = new LinkedList<>(); - pushSender.fillLocalAccountsCache(destinations, accountCache, numbersMissingDevices); + Pair, List> accountsForDevices = accounts.getAccountsForDevices(localDestinations); - List success = new LinkedList<>(); - List failure = new LinkedList<>(numbersMissingDevices); + Map localAccounts = accountsForDevices.first(); + List numbersMissingDevices = accountsForDevices.second(); + List success = new LinkedList<>(); + List failure = new LinkedList<>(numbersMissingDevices); for (RelayMessage message : messages) { - Device device = accountCache.get(new Pair<>(message.getDestination(), message.getDestinationDeviceId())); - if (device == null) + Account destinationAccount = localAccounts.get(message.getDestination()); + if (destinationAccount == null) continue; + Device device = destinationAccount.getDevice(message.getDestinationDeviceId()); OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal()) .toBuilder() .setRelay(peer.getName()) .build(); try { pushSender.sendMessage(device, signal); + success.add(device.getBackwardsCompatibleNumberEncoding()); } catch (NoSuchUserException e) { logger.info("No such user", e); - failure.add(message.getDestination()); + failure.add(device.getBackwardsCompatibleNumberEncoding()); } } @@ -169,7 +172,7 @@ public class FederationController { public ClientContacts getUserTokens(@Auth FederatedPeer peer, @PathParam("offset") int offset) { - List numberList = accounts.getAllMasterAccounts(offset, ACCOUNT_CHUNK_SIZE); + List numberList = accounts.getAllMasterDevices(offset, ACCOUNT_CHUNK_SIZE); List clientContacts = new LinkedList<>(); for (Device device : numberList) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index b8612f545..e1bd430da 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Keys; @@ -78,7 +79,11 @@ public class KeysController { UnstructuredPreKeyList keyList; if (relay == null) { - keyList = keys.get(number, accountsManager.getAllByNumber(number)); + Optional account = accountsManager.getAccount(number); + if (account.isPresent()) + keyList = keys.get(number, account.get()); + else + throw new WebApplicationException(Response.status(404).build()); } else { keyList = federatedClientManager.getClient(relay).getKeys(number); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 9f4b62b88..8adfd5fd9 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -41,6 +41,8 @@ import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.IterablePair; @@ -52,10 +54,12 @@ import javax.servlet.AsyncContext; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.Path; import java.io.BufferedReader; import java.io.IOException; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -79,20 +83,34 @@ public class MessageController extends HttpServlet { private final FederatedClientManager federatedClientManager; private final ObjectMapper objectMapper; private final ExecutorService executor; + private final AccountsManager accountsManager; public MessageController(RateLimiters rateLimiters, DeviceAuthenticator deviceAuthenticator, PushSender pushSender, + AccountsManager accountsManager, FederatedClientManager federatedClientManager) { this.rateLimiters = rateLimiters; this.deviceAuthenticator = deviceAuthenticator; this.pushSender = pushSender; + this.accountsManager = accountsManager; this.federatedClientManager = federatedClientManager; this.objectMapper = new ObjectMapper(); this.executor = Executors.newFixedThreadPool(10); } + class LocalOrRemoteDevice { + Device device; + String relay, number; long deviceId; + LocalOrRemoteDevice(Device device) { + this.device = device; this.number = device.getNumber(); this.deviceId = device.getDeviceId(); + } + LocalOrRemoteDevice(String relay, String number, long deviceId) { + this.relay = relay; this.number = number; this.deviceId = deviceId; + } + } + @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) { TimerContext timerContext = timer.time(); @@ -103,20 +121,12 @@ public class MessageController extends HttpServlet { rateLimiters.getMessagesLimiter().validate(sender.getNumber()); - - Map, Device> deviceCache = new HashMap<>(); List numbersMissingDevices = new LinkedList<>(); - List incomingMessages = messages.getMessages(); - List outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), - incomingMessages, - deviceCache, - numbersMissingDevices); + List> outgoingMessages = + getOutgoingMessageSignals(sender.getNumber(), messages.getMessages(), numbersMissingDevices); - IterablePair listPair = new IterablePair<>(incomingMessages, - outgoingMessages); - - handleAsyncDelivery(timerContext, req.startAsync(), listPair, deviceCache, numbersMissingDevices); + handleAsyncDelivery(timerContext, req.startAsync(), outgoingMessages, numbersMissingDevices); } catch (AuthenticationException e) { failureMeter.mark(); timerContext.stop(); @@ -139,8 +149,7 @@ public class MessageController extends HttpServlet { private void handleAsyncDelivery(final TimerContext timerContext, final AsyncContext context, - final IterablePair listPair, - final Map, Device> deviceCache, + final List> listPair, final List numbersMissingDevices) { executor.submit(new Runnable() { @@ -151,38 +160,37 @@ public class MessageController extends HttpServlet { HttpServletResponse response = (HttpServletResponse) context.getResponse(); try { - Map>> relayMessages = new HashMap<>(); - for (Pair messagePair : listPair) { - String destination = messagePair.first().getDestination(); - long destinationDeviceId = messagePair.first().getDestinationDeviceId(); - String relay = messagePair.first().getRelay(); + Map>> relayMessages = new HashMap<>(); + for (Pair messagePair : listPair) { + String relay = messagePair.first().relay; if (Util.isEmpty(relay)) { + String encodedId = messagePair.first().device.getBackwardsCompatibleNumberEncoding(); try { - pushSender.sendMessage(deviceCache.get(new Pair<>(destination, destinationDeviceId)), messagePair.second()); + pushSender.sendMessage(messagePair.first().device, messagePair.second()); + success.add(encodedId); } catch (NoSuchUserException e) { logger.debug("No such user", e); - failure.add(destination); + failure.add(encodedId); } } else { - Set> messageSet = relayMessages.get(relay); + Set> messageSet = relayMessages.get(relay); if (messageSet == null) { messageSet = new HashSet<>(); relayMessages.put(relay, messageSet); } messageSet.add(messagePair); } - success.add(destination); } - for (Map.Entry>> messagesForRelay : relayMessages.entrySet()) { + for (Map.Entry>> messagesForRelay : relayMessages.entrySet()) { try { FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey()); List messages = new LinkedList<>(); - for (Pair message : messagesForRelay.getValue()) { - messages.add(new RelayMessage(message.first().getDestination(), - message.first().getDestinationDeviceId(), + for (Pair message : messagesForRelay.getValue()) { + messages.add(new RelayMessage(message.first().number, + message.first().deviceId, message.second().toByteArray())); } @@ -195,8 +203,8 @@ public class MessageController extends HttpServlet { numbersMissingDevices.add(string); } catch (NoSuchPeerException e) { logger.info("No such peer", e); - for (Pair messagePair : messagesForRelay.getValue()) - failure.add(messagePair.first().getDestination()); + for (Pair messagePair : messagesForRelay.getValue()) + failure.add(messagePair.first().number); } } @@ -210,6 +218,11 @@ public class MessageController extends HttpServlet { failureMeter.mark(); response.setStatus(501); context.complete(); + } catch (Exception e) { + logger.error("Unknown error sending message", e); + failureMeter.mark(); + response.setStatus(500); + context.complete(); } timerContext.stop(); @@ -217,28 +230,32 @@ public class MessageController extends HttpServlet { }); } - /** - * @param deviceCache is a map from Pair to the account - */ @Nullable - private List getOutgoingMessageSignals(String sourceNumber, - List incomingMessages, - Map, Device> deviceCache, - List numbersMissingDevices) + private List> getOutgoingMessageSignals(String sourceNumber, + List incomingMessages, + List numbersMissingDevices) { - List outgoingMessages = new LinkedList<>(); - // # local deviceIds - Map>> destinations = new HashMap<>(); + List> outgoingMessages = new LinkedList<>(); + Map> localDestinations = new HashMap<>(); + Set destinationNumbers = new HashSet<>(); for (IncomingMessage incoming : incomingMessages) { - Pair> deviceIds = destinations.get(incoming.getDestination()); + destinationNumbers.add(incoming.getDestination()); + if (!Util.isEmpty(incoming.getRelay())) + continue; + + Set deviceIds = localDestinations.get(incoming.getDestination()); if (deviceIds == null) { - deviceIds = new Pair>(Util.isEmpty(incoming.getRelay()), new HashSet()); - destinations.put(incoming.getDestination(), deviceIds); + deviceIds = new HashSet<>(); + localDestinations.put(incoming.getDestination(), deviceIds); } - deviceIds.second().add(incoming.getDestinationDeviceId()); + deviceIds.add(incoming.getDestinationDeviceId()); } - pushSender.fillLocalAccountsCache(destinations, deviceCache, numbersMissingDevices); + Pair, List> accountsForDevices = accountsManager.getAccountsForDevices(localDestinations); + + Map localAccounts = accountsForDevices.first(); + for (String number : accountsForDevices.second()) + numbersMissingDevices.add(number); for (IncomingMessage incoming : incomingMessages) { OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder(); @@ -255,13 +272,22 @@ public class MessageController extends HttpServlet { int index = 0; - for (String destination : destinations.keySet()) { - if (!destination.equals(incoming.getDestination())) { + for (String destination : destinationNumbers) { + if (!destination.equals(incoming.getDestination())) outgoingMessage.setDestinations(index++, destination); - } } - outgoingMessages.add(outgoingMessage.build()); + LocalOrRemoteDevice device = null; + if (!Util.isEmpty(incoming.getRelay())) + device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId()); + else { + Account destination = localAccounts.get(incoming.getDestination()); + if (destination != null) + device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId())); + } + + if (device != null) + outgoingMessages.add(new Pair<>(device, outgoingMessage.build())); } return outgoingMessages; diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index 0e11b0515..b9358d12c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -16,6 +16,7 @@ */ package org.whispersystems.textsecuregcm.push; +import com.google.common.base.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; @@ -23,6 +24,7 @@ import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager; @@ -61,35 +63,6 @@ public class PushSender { this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey()); } - /** - * For each local destination in destinations, either adds all its accounts to accountCache or adds the number to - * numbersMissingDevices, if the deviceIds list don't match what is required. - * @param destinations Map from number to Pair<localNumber, Set<deviceIds>> - * @param accountCache Map from <number, deviceId> to account - * @param numbersMissingDevices list of numbers missing devices - */ - public void fillLocalAccountsCache(Map>> destinations, Map, Device> accountCache, List numbersMissingDevices) { - for (Map.Entry>> destination : destinations.entrySet()) { - if (destination.getValue().first()) { - String number = destination.getKey(); - List deviceList = accounts.getAllByNumber(number); - Set deviceIdsIncluded = destination.getValue().second(); - if (deviceList.size() != deviceIdsIncluded.size()) - numbersMissingDevices.add(number); - else { - for (Device device : deviceList) { - if (!deviceIdsIncluded.contains(device.getDeviceId())) { - numbersMissingDevices.add(number); - break; - } - } - for (Device device : deviceList) - accountCache.put(new Pair<>(number, device.getDeviceId()), device); - } - } - } - } - public void sendMessage(Device device, MessageProtos.OutgoingMessageSignal outgoingMessage) throws IOException, NoSuchUserException { @@ -99,7 +72,7 @@ public class PushSender { if (device.getGcmRegistrationId() != null) sendGcmMessage(device, message); else if (device.getApnRegistrationId() != null) sendApnMessage(device, message); else if (device.getFetchesMessages()) storeFetchedMessage(device, message); - else throw new NoSuchUserException("No push identifier!"); + else throw new NoSuchUserException("No push identifier!"); } private void sendGcmMessage(Device device, EncryptedOutgoingMessage outgoingMessage) diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java new file mode 100644 index 000000000..eb40cada0 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -0,0 +1,89 @@ +/** + * Copyright (C) 2013 Open WhisperSystems + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package org.whispersystems.textsecuregcm.storage; + + +import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; +import org.whispersystems.textsecuregcm.util.Util; + +import java.io.Serializable; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class Account implements Serializable { + private String number; + private boolean supportsSms; + private Map devices = new HashMap<>(); + + private Account(String number, boolean supportsSms) { + this.number = number; + this.supportsSms = supportsSms; + } + + public Account(String number, boolean supportsSms, Device onlyDevice) { + this(number, supportsSms); + this.devices.put(onlyDevice.getDeviceId(), onlyDevice); + } + + public Account(String number, boolean supportsSms, List devices) { + this(number, supportsSms); + for (Device device : devices) + this.devices.put(device.getDeviceId(), device); + } + + public void setNumber(String number) { + this.number = number; + } + + public String getNumber() { + return number; + } + + public boolean getSupportsSms() { + return supportsSms; + } + + public void setSupportsSms(boolean supportsSms) { + this.supportsSms = supportsSms; + } + + public boolean isActive() { + Device masterDevice = devices.get((long) 1); + return masterDevice != null && masterDevice.isActive(); + } + + public Collection getDevices() { + return devices.values(); + } + + public Device getDevice(long destinationDeviceId) { + return devices.get(destinationDeviceId); + } + + public boolean hasAllDeviceIds(Set deviceIds) { + if (devices.size() != deviceIds.size()) + return false; + for (long deviceId : devices.keySet()) { + if (!deviceIds.contains(deviceId)) + return false; + } + return true; + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 2165a179a..7aecbe040 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -37,6 +37,8 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; @@ -78,22 +80,22 @@ public abstract class Accounts { "WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id") abstract void update(@AccountBinder Device device); - @Mapper(AccountMapper.class) + @Mapper(DeviceMapper.class) @SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id") abstract Device get(@Bind("number") String number, @Bind("device_id") long deviceId); @SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts") abstract long getNumberCount(); - @Mapper(AccountMapper.class) + @Mapper(DeviceMapper.class) @SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1 OFFSET :offset LIMIT :limit") - abstract List getAllFirstAccounts(@Bind("offset") int offset, @Bind("limit") int length); + abstract List getAllMasterDevices(@Bind("offset") int offset, @Bind("limit") int length); - @Mapper(AccountMapper.class) + @Mapper(DeviceMapper.class) @SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1") - public abstract Iterator getAllFirstAccounts(); + public abstract Iterator getAllMasterDevices(); - @Mapper(AccountMapper.class) + @Mapper(DeviceMapper.class) @SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number") public abstract List getAllByNumber(@Bind("number") String number); @@ -104,8 +106,7 @@ public abstract class Accounts { return insertStep(device); } - public static class AccountMapper implements ResultSetMapper { - + public static class DeviceMapper implements ResultSetMapper { @Override public Device map(int i, ResultSet resultSet, StatementContext statementContext) throws SQLException diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 0b1ed9063..e0fa2e484 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -20,10 +20,17 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.base.Optional; import net.spy.memcached.MemcachedClient; import org.whispersystems.textsecuregcm.entities.ClientContact; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; +import sun.util.logging.resources.logging_zh_CN; +import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedList; import java.util.List; +import java.util.Map; +import java.util.Set; public class AccountsManager { @@ -44,16 +51,17 @@ public class AccountsManager { return accounts.getNumberCount(); } - public List getAllMasterAccounts(int offset, int length) { - return accounts.getAllFirstAccounts(offset, length); + public List getAllMasterDevices(int offset, int length) { + return accounts.getAllMasterDevices(offset, length); } - public Iterator getAllMasterAccounts() { - return accounts.getAllFirstAccounts(); + public Iterator getAllMasterDevices() { + return accounts.getAllMasterDevices(); } - /** Creates a new Device and NumberData, clearing all existing accounts/data on the given number */ - public void createResetNumber(Device device) { + /** Creates a new Account (WITH ONE DEVICE), clearing all existing devices on the given number */ + public void create(Account account) { + Device device = account.getDevices().iterator().next(); long id = accounts.insertClearingNumber(device); device.setId(id); @@ -64,8 +72,8 @@ public class AccountsManager { updateDirectory(device); } - /** Creates a new Device for an existing NumberData (setting the deviceId) */ - public void createAccountOnExistingNumber(Device device) { + /** Creates a new Device for an existing Account */ + public void provisionDevice(Device device) { long id = accounts.insert(device); device.setId(id); @@ -104,8 +112,43 @@ public class AccountsManager { else return Optional.absent(); } - public List getAllByNumber(String number) { - return accounts.getAllByNumber(number); + public Optional getAccount(String number) { + List devices = accounts.getAllByNumber(number); + if (devices.isEmpty()) + return Optional.absent(); + return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices)); + } + + private Map getAllAccounts(Set numbers) { + //TODO: ONE QUERY + Map result = new HashMap<>(); + for (String number : numbers) { + Optional account = getAccount(number); + if (account.isPresent()) + result.put(number, account.get()); + } + return result; + } + + public Pair, List> getAccountsForDevices(Map> destinations) { + List numbersMissingDevices = new LinkedList<>(); + Map localAccounts = getAllAccounts(destinations.keySet()); + + for (String number : destinations.keySet()) { + if (localAccounts.get(number) == null) + numbersMissingDevices.add(number); + } + + Iterator localAccountIterator = localAccounts.values().iterator(); + while (localAccountIterator.hasNext()) { + Account account = localAccountIterator.next(); + if (!account.hasAllDeviceIds(destinations.get(account.getNumber()))) { + numbersMissingDevices.add(account.getNumber()); + localAccountIterator.remove(); + } + } + + return new Pair<>(localAccounts, numbersMissingDevices); } private void updateDirectory(Device device) { diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 45412c171..d053045a1 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -125,15 +125,19 @@ public class Device implements Serializable { this.id = id; } - public void setFetchesMessages(boolean fetchesMessages) { - this.fetchesMessages = fetchesMessages; + public boolean isActive() { + return fetchesMessages || !Util.isEmpty(getApnRegistrationId()) || !Util.isEmpty(getGcmRegistrationId()); } public boolean getFetchesMessages() { return fetchesMessages; } - public boolean isActive() { - return getFetchesMessages() || !Util.isEmpty(getApnRegistrationId()) || !Util.isEmpty(getGcmRegistrationId()); + public void setFetchesMessages(boolean fetchesMessages) { + this.fetchesMessages = fetchesMessages; + } + + public String getBackwardsCompatibleNumberEncoding() { + return deviceId == 1 ? number : (number + "." + deviceId); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index a4d13475e..c7a7cf7a2 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -79,9 +79,9 @@ public abstract class Keys { } @Transaction(TransactionIsolationLevel.SERIALIZABLE) - public UnstructuredPreKeyList get(String number, List devices) { + public UnstructuredPreKeyList get(String number, Account account) { List preKeys = new LinkedList<>(); - for (Device device : devices) { + for (Device device : account.getDevices()) { PreKey preKey = retrieveFirst(number, device.getDeviceId()); if (preKey != null) preKeys.add(preKey); diff --git a/src/main/java/org/whispersystems/textsecuregcm/workers/DirectoryUpdater.java b/src/main/java/org/whispersystems/textsecuregcm/workers/DirectoryUpdater.java index f848f647e..38bd4a1a7 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/workers/DirectoryUpdater.java +++ b/src/main/java/org/whispersystems/textsecuregcm/workers/DirectoryUpdater.java @@ -53,7 +53,7 @@ public class DirectoryUpdater { BatchOperationHandle batchOperation = directory.startBatchOperation(); try { - Iterator accounts = accountsManager.getAllMasterAccounts(); + Iterator accounts = accountsManager.getAllMasterDevices(); if (accounts == null) return; diff --git a/src/test/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/src/test/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index 349d5d919..547cc49f8 100644 --- a/src/test/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/src/test/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -14,6 +14,7 @@ import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.sms.SmsSender; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; @@ -80,7 +81,7 @@ public class AccountControllerTest extends ResourceTest { ((Device)invocation.getArguments()[0]).setDeviceId(2); return null; } - }).when(accountsManager).createAccountOnExistingNumber(any(Device.class)); + }).when(accountsManager).provisionDevice(any(Device.class)); addResource(new DumbVerificationAccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender)); } @@ -107,7 +108,7 @@ public class AccountControllerTest extends ResourceTest { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager).createResetNumber(isA(Device.class)); + verify(accountsManager).create(isA(Account.class)); ArgumentCaptor number = ArgumentCaptor.forClass(String.class); verify(pendingAccountsManager).remove(number.capture()); diff --git a/src/test/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/src/test/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 32e285afb..2d306ef9d 100644 --- a/src/test/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/src/test/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -78,7 +78,7 @@ public class DeviceControllerTest extends ResourceTest { ((Device) invocation.getArguments()[0]).setDeviceId(2); return null; } - }).when(accountsManager).createAccountOnExistingNumber(any(Device.class)); + }).when(accountsManager).provisionDevice(any(Device.class)); addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, rateLimiters)); } @@ -99,7 +99,7 @@ public class DeviceControllerTest extends ResourceTest { assertThat(deviceId).isNotEqualTo(AuthHelper.DEFAULT_DEVICE_ID); ArgumentCaptor newAccount = ArgumentCaptor.forClass(Device.class); - verify(accountsManager).createAccountOnExistingNumber(newAccount.capture()); + verify(accountsManager).provisionDevice(newAccount.capture()); assertThat(deviceId).isEqualTo(newAccount.getValue().getDeviceId()); ArgumentCaptor number = ArgumentCaptor.forClass(String.class); diff --git a/src/test/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java b/src/test/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java index 88a3be63c..cfe66edcf 100644 --- a/src/test/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java +++ b/src/test/org/whispersystems/textsecuregcm/tests/controllers/KeyControllerTest.java @@ -1,5 +1,6 @@ package org.whispersystems.textsecuregcm.tests.controllers; +import com.google.common.base.Optional; import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.GenericType; import com.yammer.dropwizard.testing.ResourceTest; @@ -9,6 +10,7 @@ import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Keys; @@ -31,6 +33,7 @@ public class KeyControllerTest extends ResourceTest { private final Keys keys = mock(Keys.class); Device[] fakeDevice; + Account existsAccount; @Override protected void setUpResources() { @@ -43,17 +46,18 @@ public class KeyControllerTest extends ResourceTest { fakeDevice = new Device[2]; fakeDevice[0] = mock(Device.class); fakeDevice[1] = mock(Device.class); + existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - when(keys.get(eq(EXISTS_NUMBER), anyList())).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2))); - when(keys.get(eq(NOT_EXISTS_NUMBER), anyList())).thenReturn(null); + when(keys.get(eq(EXISTS_NUMBER), isA(Account.class))).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2))); + when(keys.get(eq(NOT_EXISTS_NUMBER), isA(Account.class))).thenReturn(null); when(fakeDevice[0].getDeviceId()).thenReturn(AuthHelper.DEFAULT_DEVICE_ID); when(fakeDevice[1].getDeviceId()).thenReturn((long) 2); - when(accounts.getAllByNumber(EXISTS_NUMBER)).thenReturn(Arrays.asList(fakeDevice[0], fakeDevice[1])); - when(accounts.getAllByNumber(NOT_EXISTS_NUMBER)).thenReturn(new LinkedList()); + when(accounts.getAccount(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount)); + when(accounts.getAccount(NOT_EXISTS_NUMBER)).thenReturn(Optional.absent()); addResource(new KeysController(rateLimiters, keys, accounts, null)); } @@ -71,7 +75,7 @@ public class KeyControllerTest extends ResourceTest { assertThat(result.getId() == 0); assertThat(result.getNumber() == null); - verify(keys).get(eq(EXISTS_NUMBER), eq(Arrays.asList(fakeDevice))); + verify(keys).get(eq(EXISTS_NUMBER), eq(existsAccount)); verifyNoMoreInteractions(keys); List results = client().resource(String.format("/v1/keys/%s?multikeys", EXISTS_NUMBER)) @@ -95,7 +99,7 @@ public class KeyControllerTest extends ResourceTest { assertThat(result.getId() == 1); assertThat(result.getNumber() == null); - verify(keys, times(2)).get(eq(EXISTS_NUMBER), eq(Arrays.asList(fakeDevice[0], fakeDevice[1]))); + verify(keys, times(2)).get(eq(EXISTS_NUMBER), eq(existsAccount)); verifyNoMoreInteractions(keys); } @@ -106,8 +110,7 @@ public class KeyControllerTest extends ResourceTest { .get(ClientResponse.class); assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404); - - verify(keys).get(NOT_EXISTS_NUMBER, new LinkedList()); + verifyNoMoreInteractions(keys); } @Test