Initial multi device support refactoring.

1) Store account data as a json type, which includes all
   devices in a single object.

2) Simplify message delivery logic.

3) Make federated calls a pass through to standard controllers.

4) Simplify key retrieval logic.
This commit is contained in:
Moxie Marlinspike 2014-01-18 23:45:07 -08:00
parent 6f9226dcf9
commit 74f71fd8a6
47 changed files with 961 additions and 1211 deletions

View File

@ -108,12 +108,6 @@
<artifactId>jersey-json</artifactId>
<version>1.17.1</version>
</dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>stringtemplate</artifactId>
<version>3.2.1</version>
</dependency>
</dependencies>
<build>

View File

@ -27,7 +27,7 @@ import com.yammer.metrics.reporting.GraphiteReporter;
import net.spy.memcached.MemcachedClient;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration;
@ -58,7 +58,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDeviceRegistrations;
import org.whispersystems.textsecuregcm.storage.PendingDevices;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
@ -96,11 +96,11 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
DBIFactory dbiFactory = new DBIFactory();
DBI jdbi = dbiFactory.build(environment, config.getDatabaseConfiguration(), "postgresql");
Accounts accounts = jdbi.onDemand(Accounts.class);
PendingAccounts pendingAccounts = jdbi.onDemand(PendingAccounts.class);
PendingDeviceRegistrations pendingDevices = jdbi.onDemand(PendingDeviceRegistrations.class);
Keys keys = jdbi.onDemand(Keys.class);
StoredMessages storedMessages = jdbi.onDemand(StoredMessages.class);
Accounts accounts = jdbi.onDemand(Accounts.class);
PendingAccounts pendingAccounts = jdbi.onDemand(PendingAccounts.class);
PendingDevices pendingDevices = jdbi.onDemand(PendingDevices.class);
Keys keys = jdbi.onDemand(Keys.class);
StoredMessages storedMessages = jdbi.onDemand(StoredMessages.class );
MemcachedClient memcachedClient = new MemcachedClientFactory(config.getMemcacheConfiguration()).getClient();
JedisPool redisClient = new RedisClientFactory(config.getRedisConfiguration()).getRedisClientPool();
@ -109,10 +109,12 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, memcachedClient);
PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, memcachedClient);
AccountsManager accountsManager = new AccountsManager(accounts, directory, memcachedClient);
DeviceAuthenticator deviceAuthenticator = new DeviceAuthenticator(accountsManager );
FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration());
StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());
@ -120,7 +122,11 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
PushSender pushSender = new PushSender(config.getGcmConfiguration(),
config.getApnConfiguration(),
storedMessageManager,
accountsManager, directory);
accountsManager);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
environment.addProvider(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),
FederatedPeer.class,
@ -130,13 +136,10 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
environment.addResource(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
environment.addResource(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
environment.addResource(new DirectoryController(rateLimiters, directory));
environment.addResource(new AttachmentController(rateLimiters, federatedClientManager, urlSigner));
environment.addResource(new KeysController(rateLimiters, keys, accountsManager, federatedClientManager));
environment.addResource(new FederationController(keys, accountsManager, pushSender, urlSigner));
environment.addServlet(new MessageController(rateLimiters, deviceAuthenticator,
pushSender, accountsManager, federatedClientManager),
MessageController.PATH);
environment.addResource(new FederationController(accountsManager, attachmentController, keysController, messageController));
environment.addResource(attachmentController);
environment.addResource(keysController);
environment.addResource(messageController);
environment.addHealthCheck(new RedisHealthCheck(redisClient));
environment.addHealthCheck(new MemcacheHealthCheck(memcachedClient));

View File

@ -24,51 +24,58 @@ import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import java.util.concurrent.TimeUnit;
public class DeviceAuthenticator implements Authenticator<BasicCredentials, Device> {
public class AccountAuthenticator implements Authenticator<BasicCredentials, Account> {
private final Meter authenticationFailedMeter = Metrics.newMeter(DeviceAuthenticator.class,
private final Meter authenticationFailedMeter = Metrics.newMeter(AccountAuthenticator.class,
"authentication", "failed",
TimeUnit.MINUTES);
private final Meter authenticationSucceededMeter = Metrics.newMeter(DeviceAuthenticator.class,
private final Meter authenticationSucceededMeter = Metrics.newMeter(AccountAuthenticator.class,
"authentication", "succeeded",
TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(DeviceAuthenticator.class);
private final Logger logger = LoggerFactory.getLogger(AccountAuthenticator.class);
private final AccountsManager accountsManager;
public DeviceAuthenticator(AccountsManager accountsManager) {
public AccountAuthenticator(AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public Optional<Device> authenticate(BasicCredentials basicCredentials)
public Optional<Account> authenticate(BasicCredentials basicCredentials)
throws AuthenticationException
{
AuthorizationHeader authorizationHeader;
try {
authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getNumber());
if (!account.isPresent()) {
return Optional.absent();
}
Optional<Device> device = account.get().getDevice(authorizationHeader.getDeviceId());
if (!device.isPresent()) {
return Optional.absent();
}
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
authenticationSucceededMeter.mark();
account.get().setAuthenticatedDevice(device.get());
return account;
}
authenticationFailedMeter.mark();
return Optional.absent();
} catch (InvalidAuthorizationHeaderException iahe) {
return Optional.absent();
}
Optional<Device> device = accountsManager.get(authorizationHeader.getNumber(), authorizationHeader.getDeviceId());
if (!device.isPresent()) {
return Optional.absent();
}
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
authenticationSucceededMeter.mark();
return device;
}
authenticationFailedMeter.mark();
return Optional.absent();
}
}

View File

@ -32,8 +32,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -54,7 +54,6 @@ import javax.ws.rs.core.Response;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
@Path("/v1/accounts")
public class AccountController {
@ -97,7 +96,7 @@ public class AccountController {
rateLimiters.getVoiceDestinationLimiter().validate(number);
break;
default:
throw new WebApplicationException(Response.status(415).build());
throw new WebApplicationException(Response.status(422).build());
}
VerificationCode verificationCode = generateVerificationCode();
@ -137,14 +136,17 @@ public class AccountController {
}
Device device = new Device();
device.setNumber(number);
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setDeviceId(0);
accounts.create(new Account(number, accountAttributes.getSupportsSms(), device));
Account account = new Account();
account.setNumber(number);
account.setSupportsSms(accountAttributes.getSupportsSms());
account.addDevice(device);
accounts.create(account);
pendingAccounts.remove(number);
@ -161,36 +163,40 @@ public class AccountController {
@PUT
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Auth Device device, @Valid GcmRegistrationId registrationId) {
device.setApnRegistrationId(null);
device.setGcmRegistrationId(registrationId.getGcmRegistrationId());
accounts.update(device);
public void setGcmRegistrationId(@Auth Account account, @Valid GcmRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
accounts.update(account);
}
@Timed
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Auth Device device) {
device.setGcmRegistrationId(null);
accounts.update(device);
public void deleteGcmRegistrationId(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
device.setGcmId(null);
accounts.update(account);
}
@Timed
@PUT
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Auth Device device, @Valid ApnRegistrationId registrationId) {
device.setApnRegistrationId(registrationId.getApnRegistrationId());
device.setGcmRegistrationId(null);
accounts.update(device);
public void setApnRegistrationId(@Auth Account account, @Valid ApnRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(registrationId.getApnRegistrationId());
device.setGcmId(null);
accounts.update(account);
}
@Timed
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Auth Device device) {
device.setApnRegistrationId(null);
accounts.update(device);
public void deleteApnRegistrationId(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
accounts.update(account);
}
@Timed

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.google.common.base.Optional;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
@ -26,7 +27,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.UrlSigner;
@ -35,6 +36,7 @@ import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
@ -64,37 +66,38 @@ public class AttachmentController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public Response allocateAttachment(@Auth Device device) throws RateLimitExceededException {
rateLimiters.getAttachmentLimiter().validate(device.getNumber());
public AttachmentDescriptor allocateAttachment(@Auth Account account)
throws RateLimitExceededException
{
if (account.isRateLimited()) {
rateLimiters.getAttachmentLimiter().validate(account.getNumber());
}
long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT);
AttachmentDescriptor descriptor = new AttachmentDescriptor(attachmentId, url.toExternalForm());
long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT);
return new AttachmentDescriptor(attachmentId, url.toExternalForm());
return Response.ok().entity(descriptor).build();
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{attachmentId}")
public Response redirectToAttachment(@Auth Device device,
@PathParam("attachmentId") long attachmentId,
@QueryParam("relay") String relay)
public AttachmentUri redirectToAttachment(@Auth Account account,
@PathParam("attachmentId") long attachmentId,
@QueryParam("relay") Optional<String> relay)
throws IOException
{
try {
URL url;
if (relay == null) url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET);
else url = federatedClientManager.getClient(relay).getSignedAttachmentUri(attachmentId);
return Response.ok().entity(new AttachmentUri(url)).build();
} catch (IOException e) {
logger.warn("No conectivity", e);
return Response.status(500).build();
if (!relay.isPresent()) {
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET));
} else {
return new AttachmentUri(federatedClientManager.getClient(relay.get()).getSignedAttachmentUri(attachmentId));
}
} catch (NoSuchPeerException e) {
logger.info("No such peer: " + relay);
return Response.status(404).build();
throw new WebApplicationException(Response.status(404).build());
}
}

View File

@ -26,7 +26,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
@ -68,13 +70,13 @@ public class DeviceController {
@GET
@Path("/provisioning_code")
@Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Device device)
public VerificationCode createDeviceToken(@Auth Account account)
throws RateLimitExceededException
{
rateLimiters.getVerifyLimiter().validate(device.getNumber()); //TODO: New limiter?
rateLimiters.getVerifyLimiter().validate(account.getNumber()); //TODO: New limiter?
VerificationCode verificationCode = generateVerificationCode();
pendingDevices.store(device.getNumber(), verificationCode.getVerificationCode());
pendingDevices.store(account.getNumber(), verificationCode.getVerificationCode());
return verificationCode;
}
@ -84,12 +86,11 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Path("/{verification_code}")
public long verifyDeviceToken(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes)
public DeviceResponse verifyDeviceToken(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes)
throws RateLimitExceededException
{
Device device;
try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber();
@ -105,24 +106,28 @@ public class DeviceController {
throw new WebApplicationException(Response.status(403).build());
}
device = new Device();
device.setNumber(number);
Optional<Account> account = accounts.get(number);
if (!account.isPresent()) {
throw new WebApplicationException(Response.status(403).build());
}
Device device = new Device();
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
accounts.provisionDevice(device);
account.get().addDevice(device);
accounts.update(account.get());
pendingDevices.remove(number);
logger.debug("Stored new device device...");
return new DeviceResponse(device.getId());
} catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad Authorization Header", e);
throw new WebApplicationException(Response.status(401).build());
}
return device.getDeviceId();
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@ -21,11 +21,11 @@ import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContactTokens;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.util.Base64;
@ -60,10 +60,10 @@ public class DirectoryController {
@GET
@Path("/{token}")
@Produces(MediaType.APPLICATION_JSON)
public Response getTokenPresence(@Auth Device device, @PathParam("token") String token)
public Response getTokenPresence(@Auth Account account, @PathParam("token") String token)
throws RateLimitExceededException
{
rateLimiters.getContactsLimiter().validate(device.getNumber());
rateLimiters.getContactsLimiter().validate(account.getNumber());
try {
Optional<ClientContact> contact = directory.get(Base64.decodeWithoutPadding(token));
@ -82,10 +82,10 @@ public class DirectoryController {
@Path("/tokens")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public ClientContacts getContactIntersection(@Auth Device device, @Valid ClientContactTokens contacts)
public ClientContacts getContactIntersection(@Auth Account account, @Valid ClientContactTokens contacts)
throws RateLimitExceededException
{
rateLimiters.getContactsLimiter().validate(device.getNumber(), contacts.getContacts().size());
rateLimiters.getContactsLimiter().validate(account.getNumber(), contacts.getContacts().size());
try {
List<byte[]> tokens = new LinkedList<>();

View File

@ -16,9 +16,7 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.google.common.base.Optional;
import com.google.protobuf.InvalidProtocolBufferException;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
@ -27,40 +25,25 @@ import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.google.common.base.Preconditions.checkState;
@Path("/v1/federation")
public class FederationController {
@ -69,16 +52,20 @@ public class FederationController {
private static final int ACCOUNT_CHUNK_SIZE = 10000;
private final PushSender pushSender;
private final Keys keys;
private final AccountsManager accounts;
private final UrlSigner urlSigner;
private final AccountsManager accounts;
private final AttachmentController attachmentController;
private final KeysController keysController;
private final MessageController messageController;
public FederationController(Keys keys, AccountsManager accounts, PushSender pushSender, UrlSigner urlSigner) {
this.keys = keys;
this.accounts = accounts;
this.pushSender = pushSender;
this.urlSigner = urlSigner;
public FederationController(AccountsManager accounts,
AttachmentController attachmentController,
KeysController keysController,
MessageController messageController)
{
this.accounts = accounts;
this.attachmentController = attachmentController;
this.keysController = keysController;
this.messageController = messageController;
}
@Timed
@ -87,82 +74,61 @@ public class FederationController {
@Produces(MediaType.APPLICATION_JSON)
public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer,
@PathParam("attachmentId") long attachmentId)
throws IOException
{
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET);
return new AttachmentUri(url);
return attachmentController.redirectToAttachment(new NonLimitedAccount("Unknown", peer.getName()),
attachmentId, Optional.<String>absent());
}
@Timed
@GET
@Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
public PreKey getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
throws IOException
{
Optional<Account> account = accounts.getAccount(number);
UnstructuredPreKeyList keyList = null;
if (account.isPresent())
keyList = keys.get(number, account.get());
if (!account.isPresent() || keyList.getKeys().isEmpty())
throw new WebApplicationException(Response.status(404).build());
return keyList;
try {
return keysController.get(new NonLimitedAccount("Unknown", peer.getName()), number, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKeys(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysController.getDeviceKey(new NonLimitedAccount("Unknown", peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@PUT
@Path("/message")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse relayMessage(@Auth FederatedPeer peer, @Valid List<RelayMessage> messages)
@Path("/messages/{source}/{destination}")
public void sendMessages(@Auth FederatedPeer peer,
@PathParam("source") String source,
@PathParam("destination") String destination,
@Valid IncomingMessageList messages)
throws IOException
{
try {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (RelayMessage message : messages) {
Set<Long> deviceIds = localDestinations.get(message.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(message.getDestination(), deviceIds);
}
deviceIds.add(message.getDestinationDeviceId());
}
List<Account> localAccounts = null;
try {
localAccounts = accounts.getAccountsForDevices(localDestinations);
} catch (MissingDevicesException e) {
return new MessageResponse(e.missingNumbers);
}
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
for (RelayMessage message : messages) {
Account destinationAccount = null;
for (Account account : localAccounts)
if (account.getNumber().equals(message.getDestination()))
destinationAccount= account;
checkState(destinationAccount != null);
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
.toBuilder()
.setRelay(peer.getName())
.build();
try {
pushSender.sendMessage(device, signal);
success.add(device.getBackwardsCompatibleNumberEncoding());
} catch (NoSuchUserException e) {
logger.info("No such user", e);
failure.add(device.getBackwardsCompatibleNumberEncoding());
}
}
return new MessageResponse(success, failure);
} catch (InvalidProtocolBufferException ipe) {
logger.warn("ProtoBuf", ipe);
throw new WebApplicationException(Response.status(400).build());
messages.setRelay(null);
messageController.sendMessage(new NonLimitedAccount(source, peer.getName()), destination, messages);
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@ -181,15 +147,16 @@ public class FederationController {
public ClientContacts getUserTokens(@Auth FederatedPeer peer,
@PathParam("offset") int offset)
{
List<Device> numberList = accounts.getAllMasterDevices(offset, ACCOUNT_CHUNK_SIZE);
List<Account> accountList = accounts.getAll(offset, ACCOUNT_CHUNK_SIZE);
List<ClientContact> clientContacts = new LinkedList<>();
for (Device device : numberList) {
byte[] token = Util.getContactToken(device.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms());
for (Account account : accountList) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
if (!device.isActive())
if (!account.isActive()) {
clientContact.setInactive(true);
}
clientContacts.add(clientContact);
}

View File

@ -29,7 +29,6 @@ import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
@ -43,7 +42,6 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.List;
@Path("/v1/keys")
public class KeysController {
@ -52,46 +50,47 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final AccountsManager accountsManager;
private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accountsManager,
public KeysController(RateLimiters rateLimiters, Keys keys,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager;
}
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Device device, @Valid PreKeyList preKeys) {
keys.store(device.getNumber(), device.getDeviceId(), preKeys.getLastResortKey(), preKeys.getKeys());
public void setKeys(@Auth Account account, @Valid PreKeyList preKeys) {
Device device = account.getAuthenticatedDevice().get();
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
}
private List<PreKey> getKeys(Device device, String number, String relay) throws RateLimitExceededException
@Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
rateLimiters.getPreKeysLimiter().validate(device.getNumber() + "__" + number);
try {
UnstructuredPreKeyList keyList;
if (relay == null) {
Optional<Account> account = accountsManager.getAccount(number);
if (account.isPresent())
keyList = keys.get(number, account.get());
else
throw new WebApplicationException(Response.status(404).build());
} else {
keyList = federatedClientManager.getClient(relay).getKeys(number);
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (keyList == null || keyList.getKeys().isEmpty()) throw new WebApplicationException(Response.status(404).build());
else return keyList.getKeys();
Optional<UnstructuredPreKeyList> results;
if (!relay.isPresent()) results = getLocalKeys(number, deviceId);
else results = federatedClientManager.getClient(relay.get()).getKeys(number, deviceId);
if (results.isPresent()) return results.get();
else throw new WebApplicationException(Response.status(404).build());
} catch (NoSuchPeerException e) {
logger.info("No peer: " + relay);
throw new WebApplicationException(Response.status(404).build());
}
}
@ -100,15 +99,27 @@ public class KeysController {
@GET
@Path("/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Response get(@Auth Device device,
@PathParam("number") String number,
@QueryParam("multikeys") Optional<String> multikey,
@QueryParam("relay") String relay)
public PreKey get(@Auth Account account,
@PathParam("number") String number,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
if (!multikey.isPresent())
return Response.ok(getKeys(device, number, relay).get(0)).type(MediaType.APPLICATION_JSON).build();
else
return Response.ok(getKeys(device, number, relay)).type(MediaType.APPLICATION_JSON).build();
UnstructuredPreKeyList results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
return results.getKeys().get(0);
}
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceId) {
try {
if (deviceId.equals("*")) {
return keys.get(number);
}
Optional<PreKey> targetKey = keys.get(number, Long.parseLong(deviceId));
if (targetKey.isPresent()) return Optional.of(new UnstructuredPreKeyList(targetKey.get()));
else return Optional.absent();
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
}

View File

@ -16,383 +16,226 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.google.protobuf.ByteString;
import com.yammer.dropwizard.auth.AuthenticationException;
import com.yammer.dropwizard.auth.basic.BasicCredentials;
import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import com.yammer.metrics.core.Timer;
import com.yammer.metrics.core.TimerContext;
import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.MissingDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class MessageController extends HttpServlet {
@Path("/v1/messages")
public class MessageController {
public static final String PATH = "/v1/messages/";
private final Meter successMeter = Metrics.newMeter(MessageController.class, "deliver_message", "success", TimeUnit.MINUTES);
private final Meter failureMeter = Metrics.newMeter(MessageController.class, "deliver_message", "failure", TimeUnit.MINUTES);
private final Timer timer = Metrics.newTimer(MessageController.class, "deliver_message_time", TimeUnit.MILLISECONDS, TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters;
private final DeviceAuthenticator deviceAuthenticator;
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
private final ObjectMapper objectMapper;
private final ExecutorService executor;
private final AccountsManager accountsManager;
public MessageController(RateLimiters rateLimiters,
DeviceAuthenticator deviceAuthenticator,
PushSender pushSender,
AccountsManager accountsManager,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.deviceAuthenticator = deviceAuthenticator;
this.pushSender = pushSender;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager;
this.objectMapper = new ObjectMapper();
this.executor = Executors.newFixedThreadPool(10);
}
class LocalOrRemoteDevice {
Device device;
String relay, number; long deviceId;
LocalOrRemoteDevice(Device device) {
this.device = device; this.number = device.getNumber(); this.deviceId = device.getDeviceId();
}
LocalOrRemoteDevice(String relay, String number, long deviceId) {
this.relay = relay; this.number = number; this.deviceId = deviceId;
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
TimerContext timerContext = timer.time();
@Timed
@Path("/{destination}")
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void sendMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
try {
Device sender = authenticate(req);
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(401);
} catch (ValidationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(415);
} catch (IOException e) {
logger.warn("IOE", e);
failureMeter.mark();
timerContext.stop();
resp.setStatus(501);
} catch (RateLimitExceededException e) {
timerContext.stop();
failureMeter.mark();
resp.setStatus(413);
if (messages.getRelay() != null) sendLocalMessage(source, destinationName, messages);
else sendRelayMessage(source, destinationName, messages);
} catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
} catch (MissingDevicesException e) {
throw new WebApplicationException(Response.status(409)
.entity(new MissingDevices(e.getMissingDevices()))
.build());
}
}
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final Device sender,
final IncomingMessageList messages)
@Timed
@Path("/")
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse sendMessageLegacy(@Auth Account source, @Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
executor.submit(new Runnable() {
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try {
List<IncomingMessage> incomingMessages = messages.getMessages();
validateLegacyDestinations(incomingMessages);
try {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages;
try {
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
} catch (MissingDevicesException e) {
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
failureMeter.mark();
timerContext.stop();
return;
}
messages.setRelay(incomingMessages.get(0).getRelay());
sendMessage(source, incomingMessages.get(0).getDestination(), messages);
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
String relay = messagePair.first().relay;
if (Util.isEmpty(relay)) {
String encodedId = messagePair.first().device.getBackwardsCompatibleNumberEncoding();
try {
pushSender.sendMessage(messagePair.first().device, messagePair.second());
success.add(encodedId);
} catch (NoSuchUserException e) {
logger.debug("No such user", e);
failure.add(encodedId);
}
} else {
Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> messageSet = relayMessages.get(relay);
if (messageSet == null) {
messageSet = new HashSet<>();
relayMessages.put(relay, messageSet);
}
messageSet.add(messagePair);
}
}
for (Map.Entry<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> messagesForRelay : relayMessages.entrySet()) {
try {
FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey());
List<RelayMessage> messages = new LinkedList<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> message : messagesForRelay.getValue()) {
messages.add(new RelayMessage(message.first().number,
message.first().deviceId,
message.second().toByteArray()));
}
MessageResponse relayResponse = client.sendMessages(messages);
for (String string : relayResponse.getSuccess())
success.add(string);
for (String string : relayResponse.getFailure())
failure.add(string);
} catch (NoSuchPeerException e) {
logger.info("No such peer", e);
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
failure.add(messagePair.first().number);
}
}
byte[] responseData = serializeResponse(new MessageResponse(success, failure));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
successMeter.mark();
} catch (IOException e) {
logger.warn("Async Handler", e);
failureMeter.mark();
response.setStatus(501);
context.complete();
} catch (Exception e) {
logger.error("Unknown error sending message", e);
failureMeter.mark();
response.setStatus(500);
context.complete();
}
timerContext.stop();
}
});
return new MessageResponse(new LinkedList<String>(), new LinkedList<String>());
} catch (ValidationException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
@Nullable
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages)
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws NoSuchUserException, MissingDevicesException, IOException
{
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) {
sendLocalMessage(source, destination, destinationDevice.get(), incomingMessage);
}
}
}
private void sendLocalMessage(Account source,
Account destinationAccount,
Device destinationDevice,
IncomingMessage incomingMessage)
throws NoSuchUserException, IOException
{
try {
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
OutgoingMessageSignal.Builder messageBuilder = OutgoingMessageSignal.newBuilder();
messageBuilder.setType(incomingMessage.getType())
.setSource(source.getNumber())
.setTimestamp(System.currentTimeMillis());
if (messageBody.isPresent()) {
messageBuilder.setMessage(ByteString.copyFrom(messageBody.get()));
}
if (source.getRelay().isPresent()) {
messageBuilder.setRelay(source.getRelay().get());
}
pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build());
} catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) throw new NoSuchUserException(e);
else logger.debug("Not registered", e);
} catch (TransientPushFailureException e) {
if (destinationDevice.isMaster()) throw new IOException(e);
else logger.debug("Transient failure", e);
}
}
private void sendRelayMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws IOException, NoSuchUserException
{
try {
FederatedClient client = federatedClientManager.getClient(messages.getRelay());
client.sendMessages(source.getNumber(), destinationName, messages);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountsManager.get(destination);
if (!account.isPresent() || !account.get().isActive()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
throws MissingDevicesException
{
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>();
Set<Long> destinationDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages));
Set<String> destinationNumbers = new HashSet<>();
for (IncomingMessage incoming : incomingMessages)
destinationNumbers.add(incoming.getDestination());
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
outgoingMessage.setType(incoming.getType());
outgoingMessage.setSource(sourceNumber);
byte[] messageBody = getMessageBody(incoming);
if (messageBody != null) {
outgoingMessage.setMessage(ByteString.copyFrom(messageBody));
}
outgoingMessage.setTimestamp(System.currentTimeMillis());
for (String destination : destinationNumbers) {
if (!destination.equals(incoming.getDestination()))
outgoingMessage.addDestinations(destination);
}
LocalOrRemoteDevice device = null;
if (!Util.isEmpty(incoming.getRelay()))
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
else {
Account destination = null;
for (Account account : localAccounts) {
if (account.getNumber().equals(incoming.getDestination())) {
destination = account;
break;
}
}
if (destination != null)
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId()));
}
if (device != null)
outgoingMessages.add(new Pair<>(device, outgoingMessage.build()));
for (IncomingMessage message : messages) {
destinationDeviceIds.add(message.getDestinationDeviceId());
}
return outgoingMessages;
}
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
for (Device device : account.getDevices()) {
if (!destinationDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
deviceIds.add(incoming.getDestinationDeviceId());
}
return localDestinations;
}
private byte[] getMessageBody(IncomingMessage message) {
try {
return Base64.decode(message.getBody());
} catch (IOException ioe) {
ioe.printStackTrace();
return null;
if (!missingDeviceIds.isEmpty()) {
throw new MissingDevicesException(missingDeviceIds);
}
}
private byte[] serializeResponse(MessageResponse response) throws IOException {
try {
return objectMapper.writeValueAsBytes(response);
} catch (JsonProcessingException e) {
throw new IOException(e);
}
}
private IncomingMessageList parseIncomingMessages(HttpServletRequest request)
throws IOException, ValidationException
private void validateLegacyDestinations(List<IncomingMessage> messages)
throws ValidationException
{
BufferedReader reader = request.getReader();
StringBuilder content = new StringBuilder();
String line;
String destination = null;
while ((line = reader.readLine()) != null) {
content.append(line);
for (IncomingMessage message : messages) {
if (destination != null && !destination.equals(message.getDestination())) {
throw new ValidationException("Multiple account destinations!");
}
destination = message.getDestination();
}
IncomingMessageList messages = objectMapper.readValue(content.toString(),
IncomingMessageList.class);
if (messages.getMessages() == null) {
throw new ValidationException();
}
for (IncomingMessage message : messages.getMessages()) {
if (message.getBody() == null) throw new ValidationException();
if (message.getDestination() == null) throw new ValidationException();
}
return messages;
}
private Device authenticate(HttpServletRequest request) throws AuthenticationException {
private Optional<byte[]> getMessageBody(IncomingMessage message) {
try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromFullHeader(request.getHeader("Authorization"));
BasicCredentials credentials = new BasicCredentials(authorizationHeader.getNumber() + "." + authorizationHeader.getDeviceId(),
authorizationHeader.getPassword() );
Optional<Device> account = deviceAuthenticator.authenticate(credentials);
if (account.isPresent()) return account.get();
else throw new AuthenticationException("Bad credentials");
} catch (InvalidAuthorizationHeaderException e) {
throw new AuthenticationException(e);
return Optional.of(Base64.decode(message.getBody()));
} catch (IOException ioe) {
logger.debug("Bad B64", ioe);
return Optional.absent();
}
}
// @Timed
// @POST
// @Consumes(MediaType.APPLICATION_JSON)
// @Produces(MediaType.APPLICATION_JSON)
// public MessageResponse sendMessage(@Auth Device sender, IncomingMessageList messages)
// throws IOException
// {
// List<String> success = new LinkedList<>();
// List<String> failure = new LinkedList<>();
// List<IncomingMessage> incomingMessages = messages.getMessages();
// List<OutgoingMessageSignal> outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), incomingMessages);
//
// IterablePair<IncomingMessage, OutgoingMessageSignal> listPair = new IterablePair<>(incomingMessages, outgoingMessages);
//
// for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : listPair) {
// String destination = messagePair.first().getDestination();
// String relay = messagePair.first().getRelay();
//
// try {
// if (Util.isEmpty(relay)) sendLocalMessage(destination, messagePair.second());
// else sendRelayMessage(relay, destination, messagePair.second());
// success.add(destination);
// } catch (NoSuchUserException e) {
// logger.debug("No such user", e);
// failure.add(destination);
// }
// }
//
// return new MessageResponse(success, failure);
// }
}

View File

@ -4,8 +4,13 @@ import java.util.List;
import java.util.Set;
public class MissingDevicesException extends Exception {
public Set<String> missingNumbers;
public MissingDevicesException(Set<String> missingNumbers) {
this.missingNumbers = missingNumbers;
private final List<Long> missingDevices;
public MissingDevicesException(List<Long> missingDevices) {
this.missingDevices = missingDevices;
}
public List<Long> getMissingDevices() {
return missingDevices;
}
}

View File

@ -18,4 +18,7 @@ package org.whispersystems.textsecuregcm.controllers;
public class ValidationException extends Exception {
public ValidationException(String s) {
super(s);
}
}

View File

@ -0,0 +1,13 @@
package org.whispersystems.textsecuregcm.entities;
public class CryptoEncodingException extends Exception {
public CryptoEncodingException(String s) {
super(s);
}
public CryptoEncodingException(Exception e) {
super(e);
}
}

View File

@ -0,0 +1,21 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
public class DeviceResponse {
@JsonProperty
private long deviceId;
@VisibleForTesting
public DeviceResponse() {}
public DeviceResponse(long deviceId) {
this.deviceId = deviceId;
}
public long getDeviceId() {
return deviceId;
}
}

View File

@ -51,7 +51,7 @@ public class EncryptedOutgoingMessage {
this.signalingKey = signalingKey;
}
public String serialize() throws IOException {
public String serialize() throws CryptoEncodingException {
byte[] plaintext = outgoingMessage.toByteArray();
SecretKeySpec cipherKey = getCipherKey (signalingKey);
SecretKeySpec macKey = getMacKey(signalingKey);
@ -61,7 +61,7 @@ public class EncryptedOutgoingMessage {
}
private byte[] getCiphertext(byte[] plaintext, SecretKeySpec cipherKey, SecretKeySpec macKey)
throws IOException
throws CryptoEncodingException
{
try {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
@ -85,31 +85,39 @@ public class EncryptedOutgoingMessage {
throw new AssertionError(e);
} catch (InvalidKeyException e) {
logger.warn("Invalid Key", e);
throw new IOException("Invalid key!");
throw new CryptoEncodingException("Invalid key!");
}
}
private SecretKeySpec getCipherKey(String signalingKey) throws IOException {
byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] cipherKey = new byte[CIPHER_KEY_SIZE];
private SecretKeySpec getCipherKey(String signalingKey) throws CryptoEncodingException {
try {
byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] cipherKey = new byte[CIPHER_KEY_SIZE];
if (signalingKeyBytes.length < CIPHER_KEY_SIZE)
throw new IOException("Signaling key too short!");
if (signalingKeyBytes.length < CIPHER_KEY_SIZE)
throw new CryptoEncodingException("Signaling key too short!");
System.arraycopy(signalingKeyBytes, 0, cipherKey, 0, cipherKey.length);
return new SecretKeySpec(cipherKey, "AES");
System.arraycopy(signalingKeyBytes, 0, cipherKey, 0, cipherKey.length);
return new SecretKeySpec(cipherKey, "AES");
} catch (IOException e) {
throw new CryptoEncodingException(e);
}
}
private SecretKeySpec getMacKey(String signalingKey) throws IOException {
byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] macKey = new byte[MAC_KEY_SIZE];
private SecretKeySpec getMacKey(String signalingKey) throws CryptoEncodingException {
try {
byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] macKey = new byte[MAC_KEY_SIZE];
if (signalingKeyBytes.length < CIPHER_KEY_SIZE + MAC_KEY_SIZE)
throw new IOException(("Signaling key too short!"));
if (signalingKeyBytes.length < CIPHER_KEY_SIZE + MAC_KEY_SIZE)
throw new CryptoEncodingException("Signaling key too short!");
System.arraycopy(signalingKeyBytes, CIPHER_KEY_SIZE, macKey, 0, macKey.length);
System.arraycopy(signalingKeyBytes, CIPHER_KEY_SIZE, macKey, 0, macKey.length);
return new SecretKeySpec(macKey, "HmacSHA256");
return new SecretKeySpec(macKey, "HmacSHA256");
} catch (IOException e) {
throw new CryptoEncodingException(e);
}
}
}

View File

@ -29,9 +29,20 @@ public class IncomingMessageList {
@Valid
private List<IncomingMessage> messages;
@JsonProperty
private String relay;
public IncomingMessageList() {}
public List<IncomingMessage> getMessages() {
return messages;
}
public String getRelay() {
return relay;
}
public void setRelay(String relay) {
this.relay = relay;
}
}

View File

@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
public class MissingDevices {
@JsonProperty
public List<Long> missingDevices;
public MissingDevices(List<Long> missingDevices) {
this.missingDevices = missingDevices;
}
}

View File

@ -35,7 +35,6 @@ public class PreKey {
private String number;
@JsonProperty
@NotNull
private long deviceId;
@JsonProperty

View File

@ -23,14 +23,24 @@ import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
public class UnstructuredPreKeyList {
@JsonProperty
@NotNull
@Valid
private List<PreKey> keys;
@VisibleForTesting
public UnstructuredPreKeyList() {}
public UnstructuredPreKeyList(PreKey preKey) {
this.keys = new LinkedList<PreKey>();
this.keys.add(preKey);
}
public UnstructuredPreKeyList(List<PreKey> preKeys) {
this.keys = preKeys;
}
@ -39,7 +49,8 @@ public class UnstructuredPreKeyList {
return keys;
}
@VisibleForTesting public boolean equals(Object o) {
@VisibleForTesting
public boolean equals(Object o) {
if (!(o instanceof UnstructuredPreKeyList) ||
((UnstructuredPreKeyList) o).keys.size() != keys.size())
return false;

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.federation;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.Client;
import com.sun.jersey.api.client.ClientHandlerException;
import com.sun.jersey.api.client.ClientResponse;
@ -34,14 +35,19 @@ import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.util.Base64;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
@ -54,6 +60,7 @@ import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.Map;
public class FederatedClient {
@ -61,8 +68,9 @@ public class FederatedClient {
private static final String USER_COUNT_PATH = "/v1/federation/user_count";
private static final String USER_TOKENS_PATH = "/v1/federation/user_tokens/%d";
private static final String RELAY_MESSAGE_PATH = "/v1/federation/message";
private static final String RELAY_MESSAGE_PATH = "/v1/federation/messages/%s/%s";
private static final String PREKEY_PATH = "/v1/federation/key/%s";
private static final String PREKEY_PATH_DEVICE = "/v1/federation/key/%s/%s";
private static final String ATTACHMENT_URI_PATH = "/v1/federation/attachment/%d";
private final FederatedPeer peer;
@ -98,15 +106,27 @@ public class FederatedClient {
}
}
public UnstructuredPreKeyList getKeys(String destination) {
public Optional<PreKey> getKey(String destination) {
try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH, destination));
return resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.get(UnstructuredPreKeyList.class);
return Optional.of(resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.get(PreKey.class));
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e);
return null;
return Optional.absent();
}
}
public Optional<UnstructuredPreKeyList> getKeys(String destination, String device) {
try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH_DEVICE, destination, device));
return Optional.of(resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.get(UnstructuredPreKeyList.class));
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e);
return Optional.absent();
}
}
@ -138,21 +158,19 @@ public class FederatedClient {
}
}
public MessageResponse sendMessages(List<RelayMessage> messages)
public void sendMessages(String source, String destination, IncomingMessageList messages)
throws IOException
{
try {
WebResource resource = client.resource(peer.getUrl()).path(RELAY_MESSAGE_PATH);
WebResource resource = client.resource(peer.getUrl()).path(String.format(RELAY_MESSAGE_PATH, source, destination));
ClientResponse response = resource.type(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.entity(messages)
.put(ClientResponse.class);
if (response.getStatus() != 200 && response.getStatus() != 204) {
throw new IOException("Bad response: " + response.getStatus());
throw new WebApplicationException(clientResponseToResponse(response));
}
return response.getEntity(MessageResponse.class);
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("sendMessage", e);
throw new IOException(e);
@ -203,6 +221,19 @@ public class FederatedClient {
}
}
private Response clientResponseToResponse(ClientResponse r) {
Response.ResponseBuilder rb = Response.status(r.getStatus());
for (Map.Entry<String, List<String>> entry : r.getHeaders().entrySet()) {
for (String value : entry.getValue()) {
rb.header(entry.getKey(), value);
}
}
rb.entity(r.getEntityInputStream());
return rb.build();
}
public String getPeerName() {
return peer.getName();
}

View File

@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.federation;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.storage.Account;
public class NonLimitedAccount extends Account {
@JsonIgnore
private final String number;
@JsonIgnore
private final String relay;
public NonLimitedAccount(String number, String relay) {
this.number = number;
this.relay = relay;
}
public String getNumber() {
return number;
}
public boolean isRateLimited() {
return false;
}
public Optional<String> getRelay() {
return Optional.of(relay);
}
}

View File

@ -59,6 +59,7 @@ public class RateLimiters {
this.messagesLimiter = new RateLimiter(memcachedClient, "messages",
config.getMessages().getBucketSize(),
config.getMessages().getLeakRatePerMinute());
}
public RateLimiter getMessagesLimiter() {

View File

@ -25,6 +25,7 @@ import com.yammer.metrics.core.Meter;
import org.bouncycastle.openssl.PEMReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.util.Util;
@ -32,7 +33,6 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.MalformedURLException;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.KeyStoreException;
@ -66,12 +66,12 @@ public class APNSender {
}
public void sendMessage(String registrationId, EncryptedOutgoingMessage message)
throws IOException
throws TransientPushFailureException, NotPushRegisteredException
{
try {
if (!apnService.isPresent()) {
failure.mark();
throw new IOException("APN access not configured!");
throw new TransientPushFailureException("APN access not configured!");
}
String payload = APNS.newPayload()
@ -83,12 +83,12 @@ public class APNSender {
apnService.get().push(registrationId, payload);
success.mark();
} catch (MalformedURLException mue) {
throw new AssertionError(mue);
} catch (NetworkIOException nioe) {
logger.warn("Network Error", nioe);
failure.mark();
throw new IOException("Error sending APN");
throw new TransientPushFailureException(nioe);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
}

View File

@ -22,7 +22,7 @@ import com.google.android.gcm.server.Result;
import com.google.android.gcm.server.Sender;
import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import java.io.IOException;
@ -40,24 +40,30 @@ public class GCMSender {
}
public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage)
throws IOException, NoSuchUserException
throws NotPushRegisteredException, TransientPushFailureException
{
Message gcmMessage = new Message.Builder().addData("type", "message")
.addData("message", outgoingMessage.serialize())
.build();
try {
Message gcmMessage = new Message.Builder().addData("type", "message")
.addData("message", outgoingMessage.serialize())
.build();
Result result = sender.send(gcmMessage, gcmRegistrationId, 5);
Result result = sender.send(gcmMessage, gcmRegistrationId, 5);
if (result.getMessageId() != null) {
success.mark();
return result.getCanonicalRegistrationId();
} else {
failure.mark();
if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) {
throw new NoSuchUserException("User no longer registered with GCM.");
if (result.getMessageId() != null) {
success.mark();
return result.getCanonicalRegistrationId();
} else {
throw new IOException("GCM Failed: " + result.getErrorCodeName());
failure.mark();
if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) {
throw new NotPushRegisteredException("Device no longer registered with GCM.");
} else {
throw new TransientPushFailureException("GCM Failed: " + result.getErrorCodeName());
}
}
} catch (IOException e) {
throw new TransientPushFailureException(e);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
}
}

View File

@ -0,0 +1,11 @@
package org.whispersystems.textsecuregcm.push;
public class NotPushRegisteredException extends Exception {
public NotPushRegisteredException(String s) {
super(s);
}
public NotPushRegisteredException(Exception e) {
super(e);
}
}

View File

@ -16,92 +16,94 @@
*/
package org.whispersystems.textsecuregcm.push;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.util.Pair;
import java.io.IOException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class PushSender {
private final Logger logger = LoggerFactory.getLogger(PushSender.class);
private final AccountsManager accounts;
private final GCMSender gcmSender;
private final APNSender apnSender;
private final AccountsManager accounts;
private final GCMSender gcmSender;
private final APNSender apnSender;
private final StoredMessageManager storedMessageManager;
public PushSender(GcmConfiguration gcmConfiguration,
ApnConfiguration apnConfiguration,
StoredMessageManager storedMessageManager,
AccountsManager accounts,
DirectoryManager directory)
AccountsManager accounts)
throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{
this.accounts = accounts;
this.accounts = accounts;
this.storedMessageManager = storedMessageManager;
this.gcmSender = new GCMSender(gcmConfiguration.getApiKey());
this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey());
}
public void sendMessage(Device device, MessageProtos.OutgoingMessageSignal outgoingMessage)
throws IOException, NoSuchUserException
public void sendMessage(Account account, Device device, MessageProtos.OutgoingMessageSignal outgoingMessage)
throws NotPushRegisteredException, TransientPushFailureException
{
String signalingKey = device.getSignalingKey();
EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey);
String signalingKey = device.getSignalingKey();
EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey);
if (device.getGcmRegistrationId() != null) sendGcmMessage(device, message);
else if (device.getApnRegistrationId() != null) sendApnMessage(device, message);
else if (device.getFetchesMessages()) storeFetchedMessage(device, message);
else throw new NoSuchUserException("No push identifier!");
if (device.getGcmId() != null) sendGcmMessage(account, device, message);
else if (device.getApnId() != null) sendApnMessage(account, device, message);
else if (device.getFetchesMessages()) storeFetchedMessage(device, message);
else throw new NotPushRegisteredException("No delivery possible!");
}
private void sendGcmMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws IOException, NoSuchUserException
private void sendGcmMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws NotPushRegisteredException, TransientPushFailureException
{
try {
String canonicalId = gcmSender.sendMessage(device.getGcmRegistrationId(),
outgoingMessage);
String canonicalId = gcmSender.sendMessage(device.getGcmId(), outgoingMessage);
if (canonicalId != null) {
device.setGcmRegistrationId(canonicalId);
accounts.update(device);
device.setGcmId(canonicalId);
accounts.update(account);
}
} catch (NoSuchUserException e) {
} catch (NotPushRegisteredException e) {
logger.debug("No Such User", e);
device.setGcmRegistrationId(null);
accounts.update(device);
throw new NoSuchUserException("User no longer exists in GCM.");
device.setGcmId(null);
accounts.update(account);
throw new NotPushRegisteredException(e);
}
}
private void sendApnMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws IOException
private void sendApnMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws TransientPushFailureException, NotPushRegisteredException
{
apnSender.sendMessage(device.getApnRegistrationId(), outgoingMessage);
try {
apnSender.sendMessage(device.getApnId(), outgoingMessage);
} catch (NotPushRegisteredException e) {
device.setApnId(null);
accounts.update(account);
throw new NotPushRegisteredException(e);
}
}
private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage) throws IOException {
storedMessageManager.storeMessage(device, outgoingMessage);
private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws NotPushRegisteredException
{
try {
storedMessageManager.storeMessage(device, outgoingMessage);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
}
}

View File

@ -0,0 +1,11 @@
package org.whispersystems.textsecuregcm.push;
public class TransientPushFailureException extends Exception {
public TransientPushFailureException(String s) {
super(s);
}
public TransientPushFailureException(Exception e) {
super(e);
}
}

View File

@ -17,35 +17,43 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.util.Util;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Optional;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class Account implements Serializable {
private String number;
private boolean supportsSms;
private Map<Long, Device> devices = new HashMap<>();
private Account(String number, boolean supportsSms) {
public static final int MEMCACHE_VERION = 2;
@JsonProperty
private String number;
@JsonProperty
private boolean supportsSms;
@JsonProperty
private List<Device> devices = new LinkedList<>();
@JsonIgnore
private Optional<Device> authenticatedDevice;
public Account() {}
public Account(String number, boolean supportsSms) {
this.number = number;
this.supportsSms = supportsSms;
}
public Account(String number, boolean supportsSms, Device onlyDevice) {
this(number, supportsSms);
addDevice(onlyDevice);
public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice;
}
public Account(String number, boolean supportsSms, List<Device> devices) {
this(number, supportsSms);
for (Device device : devices)
addDevice(device);
public void setAuthenticatedDevice(Device device) {
this.authenticatedDevice = Optional.of(device);
}
public void setNumber(String number) {
@ -64,30 +72,55 @@ public class Account implements Serializable {
this.supportsSms = supportsSms;
}
public boolean isActive() {
Device masterDevice = devices.get((long) 1);
return masterDevice != null && masterDevice.isActive();
public void addDevice(Device device) {
this.devices.add(device);
}
public Collection<Device> getDevices() {
return devices.values();
public void setDevices(List<Device> devices) {
this.devices = devices;
}
public Device getDevice(long destinationDeviceId) {
return devices.get(destinationDeviceId);
public List<Device> getDevices() {
return devices;
}
public boolean hasAllDeviceIds(Set<Long> deviceIds) {
if (devices.size() != deviceIds.size())
return false;
for (long deviceId : devices.keySet()) {
if (!deviceIds.contains(deviceId))
return false;
public Optional<Device> getMasterDevice() {
return getDevice(Device.MASTER_ID);
}
public Optional<Device> getDevice(long deviceId) {
for (Device device : devices) {
if (device.getId() == deviceId) {
return Optional.of(device);
}
}
return Optional.absent();
}
public boolean isActive() {
return
getMasterDevice().isPresent() &&
getMasterDevice().get().isActive();
}
public long getNextDeviceId() {
long highestDevice = Device.MASTER_ID;
for (Device device : devices) {
if (device.getId() > highestDevice) {
highestDevice = device.getId();
}
}
return highestDevice + 1;
}
public boolean isRateLimited() {
return true;
}
public void addDevice(Device device) {
devices.put(device.getDeviceId(), device);
public Optional<String> getRelay() {
return Optional.absent();
}
}

View File

@ -16,6 +16,10 @@
*/
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel;
@ -28,10 +32,9 @@ import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.sqlobject.stringtemplate.UseStringTemplate3StatementLocator;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.skife.jdbi.v2.unstable.BindIn;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
@ -39,91 +42,63 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
@UseStringTemplate3StatementLocator
public abstract class Accounts {
public static final String ID = "id";
public static final String NUMBER = "number";
public static final String DEVICE_ID = "device_id";
public static final String AUTH_TOKEN = "auth_token";
public static final String SALT = "salt";
public static final String SIGNALING_KEY = "signaling_key";
public static final String GCM_ID = "gcm_id";
public static final String APN_ID = "apn_id";
public static final String FETCHES_MESSAGES = "fetches_messages";
public static final String SUPPORTS_SMS = "supports_sms";
private static final String ID = "id";
private static final String NUMBER = "number";
private static final String DATA = "data";
@SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DEVICE_ID + ", " + AUTH_TOKEN + ", " +
SALT + ", " + SIGNALING_KEY + ", " + FETCHES_MESSAGES + ", " +
GCM_ID + ", " + APN_ID + ", " + SUPPORTS_SMS + ") " +
"VALUES (:number, :device_id, :auth_token, :salt, :signaling_key, :fetches_messages, :gcm_id, :apn_id, :supports_sms)")
@GetGeneratedKeys
abstract long insertStep(@AccountBinder Device device);
private static final ObjectMapper mapper = new ObjectMapper();
@SqlQuery("SELECT " + DEVICE_ID + " FROM accounts WHERE " + NUMBER + " = :number ORDER BY " + DEVICE_ID + " DESC LIMIT 1 FOR UPDATE")
abstract long getHighestDeviceId(@Bind("number") String number);
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public long insert(@AccountBinder Device device) {
device.setDeviceId(getHighestDeviceId(device.getNumber()) + 1);
return insertStep(device);
static {
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
}
@SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number RETURNING id")
abstract void removeAccountsByNumber(@Bind("number") String number);
@SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
@GetGeneratedKeys
abstract long insertStep(@AccountBinder Account account);
@SqlUpdate("UPDATE accounts SET " + AUTH_TOKEN + " = :auth_token, " + SALT + " = :salt, " +
SIGNALING_KEY + " = :signaling_key, " + GCM_ID + " = :gcm_id, " + APN_ID + " = :apn_id, " +
FETCHES_MESSAGES + " = :fetches_messages, " + SUPPORTS_SMS + " = :supports_sms " +
"WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id")
abstract void update(@AccountBinder Device device);
@SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number")
abstract void removeAccount(@Bind("number") String number);
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id")
abstract Device get(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number")
abstract void update(@AccountBinder Account account);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
abstract Account get(@Bind("number") String number);
@SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts")
abstract long getNumberCount();
abstract long getCount();
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1 OFFSET :offset LIMIT :limit")
abstract List<Device> getAllMasterDevices(@Bind("offset") int offset, @Bind("limit") int length);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts OFFSET :offset LIMIT :limit")
abstract List<Account> getAll(@Bind("offset") int offset, @Bind("limit") int length);
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1")
public abstract Iterator<Device> getAllMasterDevices();
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
public abstract List<Device> getAllByNumber(@Bind("number") String number);
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " IN ( <numbers> )")
public abstract List<Device> getAllByNumbers(@BindIn("numbers") List<String> numbers);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts")
public abstract Iterator<Account> getAll();
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public long insertClearingNumber(Device device) {
removeAccountsByNumber(device.getNumber());
device.setDeviceId(getHighestDeviceId(device.getNumber()) + 1);
return insertStep(device);
public long create(Account account) {
removeAccount(account.getNumber());
return insertStep(account);
}
public static class DeviceMapper implements ResultSetMapper<Device> {
public static class AccountMapper implements ResultSetMapper<Account> {
@Override
public Device map(int i, ResultSet resultSet, StatementContext statementContext)
public Account map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new Device(resultSet.getLong(ID), resultSet.getString(NUMBER), resultSet.getLong(DEVICE_ID),
resultSet.getString(AUTH_TOKEN), resultSet.getString(SALT),
resultSet.getString(SIGNALING_KEY), resultSet.getString(GCM_ID),
resultSet.getString(APN_ID),
resultSet.getInt(SUPPORTS_SMS) == 1, resultSet.getInt(FETCHES_MESSAGES) == 1);
try {
return mapper.readValue(resultSet.getString(DATA), Account.class);
} catch (IOException e) {
throw new SQLException(e);
}
}
}
@ -134,23 +109,20 @@ public abstract class Accounts {
public static class AccountBinderFactory implements BinderFactory {
@Override
public Binder build(Annotation annotation) {
return new Binder<AccountBinder, Device>() {
return new Binder<AccountBinder, Account>() {
@Override
public void bind(SQLStatement<?> sql,
AccountBinder accountBinder,
Device device)
Account account)
{
sql.bind(ID, device.getId());
sql.bind(NUMBER, device.getNumber());
sql.bind(DEVICE_ID, device.getDeviceId());
sql.bind(AUTH_TOKEN, device.getAuthenticationCredentials()
.getHashedAuthenticationToken());
sql.bind(SALT, device.getAuthenticationCredentials().getSalt());
sql.bind(SIGNALING_KEY, device.getSignalingKey());
sql.bind(GCM_ID, device.getGcmRegistrationId());
sql.bind(APN_ID, device.getApnRegistrationId());
sql.bind(SUPPORTS_SMS, device.getSupportsSms() ? 1 : 0);
sql.bind(FETCHES_MESSAGES, device.getFetchesMessages() ? 1 : 0);
try {
String serialized = mapper.writeValueAsString(account);
sql.bind(NUMBER, account.getNumber());
sql.bind(DATA, serialized);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
};
}

View File

@ -49,131 +49,67 @@ public class AccountsManager {
}
public long getCount() {
return accounts.getNumberCount();
return accounts.getCount();
}
public List<Device> getAllMasterDevices(int offset, int length) {
return accounts.getAllMasterDevices(offset, length);
public List<Account> getAll(int offset, int length) {
return accounts.getAll(offset, length);
}
public Iterator<Device> getAllMasterDevices() {
return accounts.getAllMasterDevices();
public Iterator<Account> getAll() {
return accounts.getAll();
}
/** Creates a new Account (WITH ONE DEVICE), clearing all existing devices on the given number */
public void create(Account account) {
Device device = account.getDevices().iterator().next();
long id = accounts.insertClearingNumber(device);
device.setId(id);
accounts.create(account);
if (memcachedClient != null) {
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device);
memcachedClient.set(getKey(account.getNumber()), 0, account);
}
updateDirectory(device);
updateDirectory(account);
}
/** Creates a new Device for an existing Account */
public void provisionDevice(Device device) {
long id = accounts.insert(device);
device.setId(id);
public void update(Account account) {
if (memcachedClient != null) {
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device);
memcachedClient.set(getKey(account.getNumber()), 0, account);
}
updateDirectory(device);
accounts.update(account);
updateDirectory(account);
}
public void update(Device device) {
if (memcachedClient != null) {
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device);
}
accounts.update(device);
updateDirectory(device);
}
public Optional<Device> get(String number, long deviceId) {
Device device = null;
public Optional<Account> get(String number) {
Account account = null;
if (memcachedClient != null) {
device = (Device)memcachedClient.get(getKey(number, deviceId));
account = (Account)memcachedClient.get(getKey(number));
}
if (device == null) {
device = accounts.get(number, deviceId);
if (account == null) {
account = accounts.get(number);
if (device != null && memcachedClient != null) {
memcachedClient.set(getKey(number, deviceId), 0, device);
if (account != null && memcachedClient != null) {
memcachedClient.set(getKey(number), 0, account);
}
}
if (device != null) return Optional.of(device);
if (account != null) return Optional.of(account);
else return Optional.absent();
}
public Optional<Account> getAccount(String number) {
List<Device> devices = accounts.getAllByNumber(number);
if (devices.isEmpty())
return Optional.absent();
return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices));
}
private List<Account> getAllAccounts(List<String> numbers) {
List<Device> devices = accounts.getAllByNumbers(numbers);
List<Account> accounts = new LinkedList<>();
for (Device device : devices) {
Account deviceAccount = null;
for (Account account : accounts) {
if (account.getNumber().equals(device.getNumber())) {
deviceAccount = account;
break;
}
}
if (deviceAccount == null) {
deviceAccount = new Account(device.getNumber(), false, device);
accounts.add(deviceAccount);
} else {
deviceAccount.addDevice(device);
}
if (device.getDeviceId() == 1)
deviceAccount.setSupportsSms(device.getSupportsSms());
}
return accounts;
}
public List<Account> getAccountsForDevices(Map<String, Set<Long>> destinations) throws MissingDevicesException {
Set<String> numbersMissingDevices = new HashSet<>(destinations.keySet());
List<Account> localAccounts = getAllAccounts(new LinkedList<>(destinations.keySet()));
for (Account account : localAccounts){
if (account.hasAllDeviceIds(destinations.get(account.getNumber())))
numbersMissingDevices.remove(account.getNumber());
}
if (!numbersMissingDevices.isEmpty())
throw new MissingDevicesException(numbersMissingDevices);
return localAccounts;
}
private void updateDirectory(Device device) {
if (device.getDeviceId() != 1)
return;
if (device.isActive()) {
byte[] token = Util.getContactToken(device.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms());
private void updateDirectory(Account account) {
if (account.isActive()) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
directory.add(clientContact);
} else {
directory.remove(device.getNumber());
directory.remove(account.getNumber());
}
}
private String getKey(String number, long accountId) {
return Device.class.getSimpleName() + Device.MEMCACHE_VERION + number + accountId;
private String getKey(String number) {
return Account.class.getSimpleName() + Account.MEMCACHE_VERION + number;
}
}

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.util.Util;
@ -24,97 +25,58 @@ import java.io.Serializable;
public class Device implements Serializable {
public static final int MEMCACHE_VERION = 1;
public static final long MASTER_ID = 1;
@JsonProperty
private long id;
private String number;
private long deviceId;
private String hashedAuthenticationToken;
@JsonProperty
private String authToken;
@JsonProperty
private String salt;
@JsonProperty
private String signalingKey;
/**
* In order for us to tell a client that an account is "inactive" (ie go use SMS for transport), we check that all
* non-fetching Accounts don't have push registrations. In this way, we can ensure that we have some form of transport
* available for all Accounts on all "active" numbers.
*/
private String gcmRegistrationId;
private String apnRegistrationId;
private boolean supportsSms;
@JsonProperty
private String gcmId;
@JsonProperty
private String apnId;
@JsonProperty
private boolean fetchesMessages;
public Device() {}
public Device(long id, String number, long deviceId, String hashedAuthenticationToken, String salt,
String signalingKey, String gcmRegistrationId, String apnRegistrationId,
boolean supportsSms, boolean fetchesMessages)
public Device(long id, String authToken, String salt,
String signalingKey, String gcmId, String apnId,
boolean fetchesMessages)
{
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.hashedAuthenticationToken = hashedAuthenticationToken;
this.salt = salt;
this.signalingKey = signalingKey;
this.gcmRegistrationId = gcmRegistrationId;
this.apnRegistrationId = apnRegistrationId;
this.supportsSms = supportsSms;
this.fetchesMessages = fetchesMessages;
this.id = id;
this.authToken = authToken;
this.salt = salt;
this.signalingKey = signalingKey;
this.gcmId = gcmId;
this.apnId = apnId;
this.fetchesMessages = fetchesMessages;
}
public String getApnRegistrationId() {
return apnRegistrationId;
public String getApnId() {
return apnId;
}
public void setApnRegistrationId(String apnRegistrationId) {
this.apnRegistrationId = apnRegistrationId;
public void setApnId(String apnId) {
this.apnId = apnId;
}
public String getGcmRegistrationId() {
return gcmRegistrationId;
public String getGcmId() {
return gcmId;
}
public void setGcmRegistrationId(String gcmRegistrationId) {
this.gcmRegistrationId = gcmRegistrationId;
}
public void setNumber(String number) {
this.number = number;
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public void setDeviceId(long deviceId) {
this.deviceId = deviceId;
}
public void setAuthenticationCredentials(AuthenticationCredentials credentials) {
this.hashedAuthenticationToken = credentials.getHashedAuthenticationToken();
this.salt = credentials.getSalt();
}
public AuthenticationCredentials getAuthenticationCredentials() {
return new AuthenticationCredentials(hashedAuthenticationToken, salt);
}
public String getSignalingKey() {
return signalingKey;
}
public void setSignalingKey(String signalingKey) {
this.signalingKey = signalingKey;
}
public boolean getSupportsSms() {
return supportsSms;
}
public void setSupportsSms(boolean supportsSms) {
this.supportsSms = supportsSms;
public void setGcmId(String gcmId) {
this.gcmId = gcmId;
}
public long getId() {
@ -125,8 +87,25 @@ public class Device implements Serializable {
this.id = id;
}
public void setAuthenticationCredentials(AuthenticationCredentials credentials) {
this.authToken = credentials.getHashedAuthenticationToken();
this.salt = credentials.getSalt();
}
public AuthenticationCredentials getAuthenticationCredentials() {
return new AuthenticationCredentials(authToken, salt);
}
public String getSignalingKey() {
return signalingKey;
}
public void setSignalingKey(String signalingKey) {
this.signalingKey = signalingKey;
}
public boolean isActive() {
return fetchesMessages || !Util.isEmpty(getApnRegistrationId()) || !Util.isEmpty(getGcmRegistrationId());
return fetchesMessages || !Util.isEmpty(getApnId()) || !Util.isEmpty(getGcmId());
}
public boolean getFetchesMessages() {
@ -137,7 +116,7 @@ public class Device implements Serializable {
this.fetchesMessages = fetchesMessages;
}
public String getBackwardsCompatibleNumberEncoding() {
return deviceId == 1 ? number : (number + "." + deviceId);
public boolean isMaster() {
return getId() == MASTER_ID;
}
}

View File

@ -16,6 +16,7 @@
*/
package org.whispersystems.textsecuregcm.storage;
import com.google.common.base.Optional;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel;
@ -62,8 +63,11 @@ public abstract class Keys {
@Mapper(PreKeyMapper.class)
abstract PreKey retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY key_id ASC FOR UPDATE")
abstract List<PreKey> retrieveFirst(@Bind("number") String number);
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, PreKey lastResortKey, List<PreKey> keys) {
public void store(String number, long deviceId, List<PreKey> keys, PreKey lastResortKey) {
for (PreKey key : keys) {
key.setNumber(number);
key.setDeviceId(deviceId);
@ -79,20 +83,31 @@ public abstract class Keys {
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public UnstructuredPreKeyList get(String number, Account account) {
List<PreKey> preKeys = new LinkedList<>();
for (Device device : account.getDevices()) {
PreKey preKey = retrieveFirst(number, device.getDeviceId());
if (preKey != null)
preKeys.add(preKey);
public Optional<PreKey> get(String number, long deviceId) {
PreKey preKey = retrieveFirst(number, deviceId);
if (preKey != null && !preKey.isLastResort()) {
removeKey(preKey.getId());
}
for (PreKey preKey : preKeys) {
if (!preKey.isLastResort())
removeKey(preKey.getId());
if (preKey != null) return Optional.of(preKey);
else return Optional.absent();
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<UnstructuredPreKeyList> get(String number) {
List<PreKey> preKeys = retrieveFirst(number);
if (preKeys != null) {
for (PreKey preKey : preKeys) {
if (!preKey.isLastResort()) {
removeKey(preKey.getId());
}
}
}
return new UnstructuredPreKeyList(preKeys);
if (preKeys != null) return Optional.of(new UnstructuredPreKeyList(preKeys));
else return Optional.absent();
}
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)

View File

@ -20,7 +20,7 @@ import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
public interface PendingDeviceRegistrations {
public interface PendingDevices {
@SqlUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code WHERE number = :number RETURNING *) " +
"INSERT INTO pending_devices (number, verification_code) SELECT :number, :verification_code WHERE NOT EXISTS (SELECT * FROM upsert)")

View File

@ -23,10 +23,10 @@ public class PendingDevicesManager {
private static final String MEMCACHE_PREFIX = "pending_devices";
private final PendingDeviceRegistrations pendingDevices;
private final MemcachedClient memcachedClient;
private final PendingDevices pendingDevices;
private final MemcachedClient memcachedClient;
public PendingDevicesManager(PendingDeviceRegistrations pendingDevices,
public PendingDevicesManager(PendingDevices pendingDevices,
MemcachedClient memcachedClient)
{
this.pendingDevices = pendingDevices;
@ -42,8 +42,10 @@ public class PendingDevicesManager {
}
public void remove(String number) {
if (memcachedClient != null)
if (memcachedClient != null) {
memcachedClient.delete(MEMCACHE_PREFIX + number);
}
pendingDevices.remove(number);
}

View File

@ -16,6 +16,7 @@
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import java.io.IOException;
@ -27,7 +28,9 @@ public class StoredMessageManager {
this.storedMessages = storedMessages;
}
public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage) throws IOException {
public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws CryptoEncodingException
{
storedMessages.insert(device.getId(), outgoingMessage.serialize());
}

View File

@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle;
@ -53,22 +53,23 @@ public class DirectoryUpdater {
BatchOperationHandle batchOperation = directory.startBatchOperation();
try {
Iterator<Device> accounts = accountsManager.getAllMasterDevices();
Iterator<Account> accounts = accountsManager.getAll();
if (accounts == null)
return;
while (accounts.hasNext()) {
Device device = accounts.next();
if (device.isActive()) {
byte[] token = Util.getContactToken(device.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms());
Account account = accounts.next();
if (account.isActive()) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
directory.add(batchOperation, clientContact);
logger.debug("Adding local token: " + Base64.encodeBytesWithoutPadding(token));
} else {
directory.remove(batchOperation, device.getNumber());
directory.remove(batchOperation, account.getNumber());
}
}
} finally {

View File

@ -77,19 +77,24 @@
</changeSet>
<changeSet id="2" author="matt">
<addColumn tableName="accounts">
<column name="device_id" type="bigint">
<constraints nullable="false" />
</column>
<column name="fetches_messages" type="smallint" defaultValue="0"/>
<addColumn tableName="accounts">
<column name="data" type="json" />
</addColumn>
<dropUniqueConstraint tableName="accounts" constraintName="accounts_number_key" />
<addUniqueConstraint constraintName="account_number_device_unique" tableName="accounts" columnNames="number, device_id" />
<sql>UPDATE accounts SET data = CAST(('{"number" : "' || number || '", "supportsSms" : ' || supports_sms || ', "devices" : [{"id" : 1, "authToken" : "' || auth_token || '", "salt" : "' || salt || '"' || CASE WHEN signaling_key IS NOT NULL THEN ', "signalingKey" : "' || signaling_key || '"' ELSE '' END || CASE WHEN gcm_id IS NOT NULL THEN ', "gcmId" : "' || gcm_id || '"' ELSE '' END || CASE WHEN apn_id IS NOT NULL THEN ', "apnId" : "' || apn_id || '"' ELSE '' END || '}]}') AS json);</sql>
<addNotNullConstraint tableName="accounts" columnName="data"/>
<dropColumn tableName="accounts" columnName="auth_token"/>
<dropColumn tableName="accounts" columnName="salt"/>
<dropColumn tableName="accounts" columnName="signaling_key"/>
<dropColumn tableName="accounts" columnName="gcm_id"/>
<dropColumn tableName="accounts" columnName="apn_id"/>
<dropColumn tableName="accounts" columnName="supports_sms"/>
<addColumn tableName="keys">
<column name="device_id" type="bigint" >
<column name="device_id" type="bigint" defaultValue="1">
<constraints nullable="false" />
</column>
</addColumn>

View File

@ -1,61 +1,26 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import com.yammer.dropwizard.testing.ResourceTest;
import org.hibernate.validator.constraints.NotEmpty;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import javax.ws.rs.Path;
import javax.ws.rs.core.MediaType;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.*;
public class AccountControllerTest extends ResourceTest {
/** The AccountAttributes used in protocol v1 (no fetchesMessages) */
static class V1AccountAttributes {
@JsonProperty
@NotEmpty
private String signalingKey;
@JsonProperty
private boolean supportsSms;
public V1AccountAttributes(String signalingKey, boolean supportsSms) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
}
}
@Path("/v1/accounts")
static class DumbVerificationAccountController extends AccountController {
public DumbVerificationAccountController(PendingAccountsManager pendingAccounts, AccountsManager accounts, RateLimiters rateLimiters, SmsSender smsSenderFactory) {
super(pendingAccounts, accounts, rateLimiters, smsSenderFactory);
}
@Override
protected VerificationCode generateVerificationCode() {
return new VerificationCode(5678901);
}
}
private static final String SENDER = "+14152222222";
@ -75,15 +40,10 @@ public class AccountControllerTest extends ResourceTest {
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of("1234"));
Mockito.doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
((Device)invocation.getArguments()[0]).setDeviceId(2);
return null;
}
}).when(accountsManager).provisionDevice(any(Device.class));
addResource(new DumbVerificationAccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
addResource(new AccountController(pendingAccountsManager,
accountsManager,
rateLimiters,
smsSender));
}
@Test
@ -102,17 +62,13 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1234"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new V1AccountAttributes("keykeykeykey", false))
.entity(new AccountAttributes("keykeykeykey", false, false))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).create(isA(Account.class));
ArgumentCaptor<String> number = ArgumentCaptor.forClass(String.class);
verify(pendingAccountsManager).remove(number.capture());
assertThat(number.getValue()).isEqualTo(SENDER);
}
@Test
@ -120,7 +76,7 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1111"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new V1AccountAttributes("keykeykeykey", false))
.entity(new AccountAttributes("keykeykeykey", false, false))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
@ -129,4 +85,4 @@ public class AccountControllerTest extends ResourceTest {
verifyNoMoreInteractions(accountsManager);
}
}
}

View File

@ -19,15 +19,12 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional;
import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -37,10 +34,7 @@ import javax.ws.rs.Path;
import javax.ws.rs.core.MediaType;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;
public class DeviceControllerTest extends ResourceTest {
@Path("/v1/devices")
@ -55,12 +49,11 @@ public class DeviceControllerTest extends ResourceTest {
}
}
private static final String SENDER = "+14152222222";
private PendingDevicesManager pendingDevicesManager = mock(PendingDevicesManager.class);
private AccountsManager accountsManager = mock(AccountsManager.class );
private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class );
private Account account = mock(Account.class );
@Override
protected void setUpResources() throws Exception {
@ -70,15 +63,10 @@ public class DeviceControllerTest extends ResourceTest {
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of("5678901"));
when(account.getNextDeviceId()).thenReturn(42L);
Mockito.doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
((Device) invocation.getArguments()[0]).setDeviceId(2);
return null;
}
}).when(accountsManager).provisionDevice(any(Device.class));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of("5678901"));
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, rateLimiters));
}
@ -91,19 +79,14 @@ public class DeviceControllerTest extends ResourceTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
Long deviceId = client().resource(String.format("/v1/devices/5678901"))
DeviceResponse response = client().resource("/v1/devices/5678901")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.entity(new AccountAttributes("keykeykeykey", false, true))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(Long.class);
assertThat(deviceId).isNotEqualTo(AuthHelper.DEFAULT_DEVICE_ID);
.put(DeviceResponse.class);
ArgumentCaptor<Device> newAccount = ArgumentCaptor.forClass(Device.class);
verify(accountsManager).provisionDevice(newAccount.capture());
assertThat(deviceId).isEqualTo(newAccount.getValue().getDeviceId());
assertThat(response.getDeviceId()).isEqualTo(42L);
ArgumentCaptor<String> number = ArgumentCaptor.forClass(String.class);
verify(pendingDevicesManager).remove(number.capture());
assertThat(number.getValue()).isEqualTo(AuthHelper.VALID_NUMBER);
verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
}
}

View File

@ -2,7 +2,6 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import com.sun.jersey.api.client.GenericType;
import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.KeysController;
@ -10,13 +9,10 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
@ -28,42 +24,32 @@ public class KeyControllerTest extends ResourceTest {
private final String EXISTS_NUMBER = "+14152222222";
private final String NOT_EXISTS_NUMBER = "+14152222220";
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, AuthHelper.DEFAULT_DEVICE_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false);
private final Keys keys = mock(Keys.class);
Device[] fakeDevice;
Account existsAccount;
@Override
protected void setUpResources() {
addProvider(AuthHelper.getAuthenticator());
RateLimiters rateLimiters = mock(RateLimiters.class);
RateLimiter rateLimiter = mock(RateLimiter.class );
AccountsManager accounts = mock(AccountsManager.class);
fakeDevice = new Device[2];
fakeDevice[0] = mock(Device.class);
fakeDevice[1] = mock(Device.class);
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(keys.get(eq(EXISTS_NUMBER), isA(Account.class))).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2)));
when(keys.get(eq(NOT_EXISTS_NUMBER), isA(Account.class))).thenReturn(null);
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY));
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<PreKey>absent());
when(fakeDevice[0].getDeviceId()).thenReturn(AuthHelper.DEFAULT_DEVICE_ID);
when(fakeDevice[1].getDeviceId()).thenReturn((long) 2);
List<PreKey> allKeys = new LinkedList<>();
allKeys.add(SAMPLE_KEY);
allKeys.add(SAMPLE_KEY2);
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(new UnstructuredPreKeyList(allKeys)));
when(accounts.getAccount(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount));
when(accounts.getAccount(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>absent());
addResource(new KeysController(rateLimiters, keys, accounts, null));
addResource(new KeysController(rateLimiters, keys, null));
}
@Test
public void validRequestsTest() throws Exception {
public void validLegacyRequestTest() throws Exception {
PreKey result = client().resource(String.format("/v1/keys/%s", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKey.class);
@ -75,15 +61,20 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
verify(keys).get(eq(EXISTS_NUMBER), eq(existsAccount));
verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys);
}
List<PreKey> results = client().resource(String.format("/v1/keys/%s?multikeys", EXISTS_NUMBER))
@Test
public void validMultiRequestTest() throws Exception {
UnstructuredPreKeyList results = client().resource(String.format("/v1/keys/%s/*", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(new GenericType<List<PreKey>>(){});
.get(UnstructuredPreKeyList.class);
assertThat(results.getKeys().size()).isEqualTo(2);
PreKey result = results.getKeys().get(0);
assertThat(results.size()).isEqualTo(2);
result = results.get(0);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY.getIdentityKey());
@ -91,18 +82,19 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
result = results.get(1);
result = results.getKeys().get(1);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY2.getIdentityKey());
assertThat(result.getId() == 1);
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
verify(keys, times(2)).get(eq(EXISTS_NUMBER), eq(existsAccount));
verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys);
}
@Test
public void invalidRequestTest() throws Exception {
ClientResponse response = client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
@ -110,7 +102,6 @@ public class KeyControllerTest extends ResourceTest {
.get(ClientResponse.class);
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404);
verifyNoMoreInteractions(keys);
}
@Test

View File

@ -1,56 +1,48 @@
package org.whispersystems.textsecuregcm.tests.util;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
import org.whispersystems.textsecuregcm.configuration.FederationConfiguration;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Base64;
import java.util.Arrays;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AuthHelper {
public static final long DEFAULT_DEVICE_ID = 1;
public static final String VALID_NUMBER = "+14150000000";
public static final String VALID_PASSWORD = "foo";
public static final String INVVALID_NUMBER = "+14151111111";
public static final String INVALID_PASSWORD = "bar";
public static final String VALID_FEDERATION_PEER = "valid_peer";
public static final String FEDERATION_PEER_TOKEN = "magic";
public static MultiBasicAuthProvider<FederatedPeer, Device> getAuthenticator() {
FederationConfiguration federationConfig = mock(FederationConfiguration.class);
when(federationConfig.getPeers()).thenReturn(Arrays.asList(new FederatedPeer(VALID_FEDERATION_PEER, "", FEDERATION_PEER_TOKEN, "")));
AccountsManager accounts = mock(AccountsManager.class);
Device device = mock(Device.class);
public static MultiBasicAuthProvider<FederatedPeer, Account> getAuthenticator() {
AccountsManager accounts = mock(AccountsManager.class );
Account account = mock(Account.class );
Device device = mock(Device.class );
AuthenticationCredentials credentials = mock(AuthenticationCredentials.class);
when(credentials.verify("foo")).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(accounts.get(VALID_NUMBER, DEFAULT_DEVICE_ID)).thenReturn(Optional.of(device));
when(account.getDevice(anyLong())).thenReturn(Optional.of(device));
when(accounts.get(VALID_NUMBER)).thenReturn(Optional.of(account));
return new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(federationConfig),
return new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(new FederationConfiguration()),
FederatedPeer.class,
new DeviceAuthenticator(accounts),
Device.class, "WhisperServer");
new AccountAuthenticator(accounts),
Account.class, "WhisperServer");
}
public static String getAuthHeader(String number, String password) {
return "Basic " + Base64.encodeBytes((number + ":" + password).getBytes());
}
public static String getV2AuthHeader(String number, long deviceId, String password) {
return "Basic " + Base64.encodeBytes((number + "." + deviceId + ":" + password).getBytes());
}
}

View File

@ -1,11 +0,0 @@
package org.whispersystems.textsecuregcm.tests;
/**
* Created with IntelliJ IDEA.
* User: moxie
* Date: 10/28/13
* Time: 12:53 PM
* To change this template use File | Settings | File Templates.
*/
public class BaseTest {
}

View File

@ -1,127 +0,0 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import com.sun.jersey.api.client.GenericType;
import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.FederationController;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.MissingDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import javax.ws.rs.core.MediaType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class FederatedControllerTest extends ResourceTest {
private final String EXISTS_NUMBER = "+14152222222";
private final String EXISTS_NUMBER_2 = "+14154444444";
private final String NOT_EXISTS_NUMBER = "+14152222220";
private final Keys keys = mock(Keys.class);
private final PushSender pushSender = mock(PushSender.class);
Device[] fakeDevice;
Account existsAccount;
@Override
protected void setUpResources() throws MissingDevicesException {
addProvider(AuthHelper.getAuthenticator());
RateLimiters rateLimiters = mock(RateLimiters.class);
RateLimiter rateLimiter = mock(RateLimiter.class );
AccountsManager accounts = mock(AccountsManager.class);
fakeDevice = new Device[2];
fakeDevice[0] = new Device(42, EXISTS_NUMBER, 1, "", "", "", null, null, true, false);
fakeDevice[1] = new Device(43, EXISTS_NUMBER, 2, "", "", "", null, null, false, true);
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
Map<String, Set<Long>> validOneElementSet = new HashMap<>();
validOneElementSet.put(EXISTS_NUMBER_2, new HashSet<>(Arrays.asList((long) 1)));
List<Account> validOneAccount = Arrays.asList(new Account(EXISTS_NUMBER_2, true,
Arrays.asList(new Device(44, EXISTS_NUMBER_2, 1, "", "", "", null, null, true, false))));
Map<String, Set<Long>> validTwoElementsSet = new HashMap<>();
validTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1, (long) 2)));
List<Account> validTwoAccount = Arrays.asList(new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])));
Map<String, Set<Long>> invalidTwoElementsSet = new HashMap<>();
invalidTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1)));
when(accounts.getAccountsForDevices(eq(validOneElementSet))).thenReturn(validOneAccount);
when(accounts.getAccountsForDevices(eq(validTwoElementsSet))).thenReturn(validTwoAccount);
when(accounts.getAccountsForDevices(eq(invalidTwoElementsSet))).thenThrow(new MissingDevicesException(new HashSet<>(Arrays.asList(EXISTS_NUMBER))));
addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class)));
}
@Test
public void validRequestsTest() throws Exception {
MessageResponse result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER_2, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER_2));
assertThat(result.getFailure()).isEmpty();
assertThat(result.getNumbersMissingDevices()).isEmpty();
result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()),
new RelayMessage(EXISTS_NUMBER, 2, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER, EXISTS_NUMBER + "." + 2));
assertThat(result.getFailure()).isEmpty();
assertThat(result.getNumbersMissingDevices()).isEmpty();
}
@Test
public void invalidRequestTest() throws Exception {
MessageResponse result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEmpty();
assertThat(result.getFailure()).isEqualTo(Arrays.asList(EXISTS_NUMBER));
assertThat(result.getNumbersMissingDevices()).isEqualTo(new HashSet<>(Arrays.asList(EXISTS_NUMBER)));
}
}