Initial multi device support refactoring.

1) Store account data as a json type, which includes all
   devices in a single object.

2) Simplify message delivery logic.

3) Make federated calls a pass through to standard controllers.

4) Simplify key retrieval logic.
This commit is contained in:
Moxie Marlinspike 2014-01-18 23:45:07 -08:00
parent 6f9226dcf9
commit 74f71fd8a6
47 changed files with 961 additions and 1211 deletions

View File

@ -108,12 +108,6 @@
<artifactId>jersey-json</artifactId> <artifactId>jersey-json</artifactId>
<version>1.17.1</version> <version>1.17.1</version>
</dependency> </dependency>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>stringtemplate</artifactId>
<version>3.2.1</version>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@ -27,7 +27,7 @@ import com.yammer.metrics.reporting.GraphiteReporter;
import net.spy.memcached.MemcachedClient; import net.spy.memcached.MemcachedClient;
import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.skife.jdbi.v2.DBI; import org.skife.jdbi.v2.DBI;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator; import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider; import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration; import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration;
@ -58,7 +58,7 @@ import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.PendingAccounts; import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDeviceRegistrations; import org.whispersystems.textsecuregcm.storage.PendingDevices;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager; import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager; import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.storage.StoredMessages;
@ -98,9 +98,9 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
Accounts accounts = jdbi.onDemand(Accounts.class); Accounts accounts = jdbi.onDemand(Accounts.class);
PendingAccounts pendingAccounts = jdbi.onDemand(PendingAccounts.class); PendingAccounts pendingAccounts = jdbi.onDemand(PendingAccounts.class);
PendingDeviceRegistrations pendingDevices = jdbi.onDemand(PendingDeviceRegistrations.class); PendingDevices pendingDevices = jdbi.onDemand(PendingDevices.class);
Keys keys = jdbi.onDemand(Keys.class); Keys keys = jdbi.onDemand(Keys.class);
StoredMessages storedMessages = jdbi.onDemand(StoredMessages.class); StoredMessages storedMessages = jdbi.onDemand(StoredMessages.class );
MemcachedClient memcachedClient = new MemcachedClientFactory(config.getMemcacheConfiguration()).getClient(); MemcachedClient memcachedClient = new MemcachedClientFactory(config.getMemcacheConfiguration()).getClient();
JedisPool redisClient = new RedisClientFactory(config.getRedisConfiguration()).getRedisClientPool(); JedisPool redisClient = new RedisClientFactory(config.getRedisConfiguration()).getRedisClientPool();
@ -109,10 +109,12 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, memcachedClient); PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, memcachedClient);
PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, memcachedClient); PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, memcachedClient);
AccountsManager accountsManager = new AccountsManager(accounts, directory, memcachedClient); AccountsManager accountsManager = new AccountsManager(accounts, directory, memcachedClient);
DeviceAuthenticator deviceAuthenticator = new DeviceAuthenticator(accountsManager );
FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration()); FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration());
StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages); StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient); RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration()); TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration());
Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration()); Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational()); SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());
@ -120,7 +122,11 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
PushSender pushSender = new PushSender(config.getGcmConfiguration(), PushSender pushSender = new PushSender(config.getGcmConfiguration(),
config.getApnConfiguration(), config.getApnConfiguration(),
storedMessageManager, storedMessageManager,
accountsManager, directory); accountsManager);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
environment.addProvider(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()), environment.addProvider(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),
FederatedPeer.class, FederatedPeer.class,
@ -130,13 +136,10 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
environment.addResource(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender)); environment.addResource(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
environment.addResource(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters)); environment.addResource(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
environment.addResource(new DirectoryController(rateLimiters, directory)); environment.addResource(new DirectoryController(rateLimiters, directory));
environment.addResource(new AttachmentController(rateLimiters, federatedClientManager, urlSigner)); environment.addResource(new FederationController(accountsManager, attachmentController, keysController, messageController));
environment.addResource(new KeysController(rateLimiters, keys, accountsManager, federatedClientManager)); environment.addResource(attachmentController);
environment.addResource(new FederationController(keys, accountsManager, pushSender, urlSigner)); environment.addResource(keysController);
environment.addResource(messageController);
environment.addServlet(new MessageController(rateLimiters, deviceAuthenticator,
pushSender, accountsManager, federatedClientManager),
MessageController.PATH);
environment.addHealthCheck(new RedisHealthCheck(redisClient)); environment.addHealthCheck(new RedisHealthCheck(redisClient));
environment.addHealthCheck(new MemcacheHealthCheck(memcachedClient)); environment.addHealthCheck(new MemcacheHealthCheck(memcachedClient));

View File

@ -24,40 +24,43 @@ import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter; import com.yammer.metrics.core.Meter;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class DeviceAuthenticator implements Authenticator<BasicCredentials, Device> { public class AccountAuthenticator implements Authenticator<BasicCredentials, Account> {
private final Meter authenticationFailedMeter = Metrics.newMeter(DeviceAuthenticator.class, private final Meter authenticationFailedMeter = Metrics.newMeter(AccountAuthenticator.class,
"authentication", "failed", "authentication", "failed",
TimeUnit.MINUTES); TimeUnit.MINUTES);
private final Meter authenticationSucceededMeter = Metrics.newMeter(DeviceAuthenticator.class, private final Meter authenticationSucceededMeter = Metrics.newMeter(AccountAuthenticator.class,
"authentication", "succeeded", "authentication", "succeeded",
TimeUnit.MINUTES); TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(DeviceAuthenticator.class); private final Logger logger = LoggerFactory.getLogger(AccountAuthenticator.class);
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
public DeviceAuthenticator(AccountsManager accountsManager) { public AccountAuthenticator(AccountsManager accountsManager) {
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
} }
@Override @Override
public Optional<Device> authenticate(BasicCredentials basicCredentials) public Optional<Account> authenticate(BasicCredentials basicCredentials)
throws AuthenticationException throws AuthenticationException
{ {
AuthorizationHeader authorizationHeader;
try { try {
authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword()); AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
} catch (InvalidAuthorizationHeaderException iahe) { Optional<Account> account = accountsManager.get(authorizationHeader.getNumber());
if (!account.isPresent()) {
return Optional.absent(); return Optional.absent();
} }
Optional<Device> device = accountsManager.get(authorizationHeader.getNumber(), authorizationHeader.getDeviceId());
Optional<Device> device = account.get().getDevice(authorizationHeader.getDeviceId());
if (!device.isPresent()) { if (!device.isPresent()) {
return Optional.absent(); return Optional.absent();
@ -65,10 +68,14 @@ public class DeviceAuthenticator implements Authenticator<BasicCredentials, Devi
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) { if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
authenticationSucceededMeter.mark(); authenticationSucceededMeter.mark();
return device; account.get().setAuthenticatedDevice(device.get());
return account;
} }
authenticationFailedMeter.mark(); authenticationFailedMeter.mark();
return Optional.absent(); return Optional.absent();
} catch (InvalidAuthorizationHeaderException iahe) {
return Optional.absent();
}
} }
} }

View File

@ -32,8 +32,8 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -54,7 +54,6 @@ import javax.ws.rs.core.Response;
import java.io.IOException; import java.io.IOException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.Arrays;
@Path("/v1/accounts") @Path("/v1/accounts")
public class AccountController { public class AccountController {
@ -97,7 +96,7 @@ public class AccountController {
rateLimiters.getVoiceDestinationLimiter().validate(number); rateLimiters.getVoiceDestinationLimiter().validate(number);
break; break;
default: default:
throw new WebApplicationException(Response.status(415).build()); throw new WebApplicationException(Response.status(422).build());
} }
VerificationCode verificationCode = generateVerificationCode(); VerificationCode verificationCode = generateVerificationCode();
@ -137,14 +136,17 @@ public class AccountController {
} }
Device device = new Device(); Device device = new Device();
device.setNumber(number); device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey()); device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setDeviceId(0);
accounts.create(new Account(number, accountAttributes.getSupportsSms(), device)); Account account = new Account();
account.setNumber(number);
account.setSupportsSms(accountAttributes.getSupportsSms());
account.addDevice(device);
accounts.create(account);
pendingAccounts.remove(number); pendingAccounts.remove(number);
@ -161,36 +163,40 @@ public class AccountController {
@PUT @PUT
@Path("/gcm/") @Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Auth Device device, @Valid GcmRegistrationId registrationId) { public void setGcmRegistrationId(@Auth Account account, @Valid GcmRegistrationId registrationId) {
device.setApnRegistrationId(null); Device device = account.getAuthenticatedDevice().get();
device.setGcmRegistrationId(registrationId.getGcmRegistrationId()); device.setApnId(null);
accounts.update(device); device.setGcmId(registrationId.getGcmRegistrationId());
accounts.update(account);
} }
@Timed @Timed
@DELETE @DELETE
@Path("/gcm/") @Path("/gcm/")
public void deleteGcmRegistrationId(@Auth Device device) { public void deleteGcmRegistrationId(@Auth Account account) {
device.setGcmRegistrationId(null); Device device = account.getAuthenticatedDevice().get();
accounts.update(device); device.setGcmId(null);
accounts.update(account);
} }
@Timed @Timed
@PUT @PUT
@Path("/apn/") @Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Auth Device device, @Valid ApnRegistrationId registrationId) { public void setApnRegistrationId(@Auth Account account, @Valid ApnRegistrationId registrationId) {
device.setApnRegistrationId(registrationId.getApnRegistrationId()); Device device = account.getAuthenticatedDevice().get();
device.setGcmRegistrationId(null); device.setApnId(registrationId.getApnRegistrationId());
accounts.update(device); device.setGcmId(null);
accounts.update(account);
} }
@Timed @Timed
@DELETE @DELETE
@Path("/apn/") @Path("/apn/")
public void deleteApnRegistrationId(@Auth Device device) { public void deleteApnRegistrationId(@Auth Account account) {
device.setApnRegistrationId(null); Device device = account.getAuthenticatedDevice().get();
accounts.update(device); device.setApnId(null);
accounts.update(account);
} }
@Timed @Timed

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod; import com.amazonaws.HttpMethod;
import com.google.common.base.Optional;
import com.yammer.dropwizard.auth.Auth; import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed; import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -26,7 +27,7 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.UrlSigner; import org.whispersystems.textsecuregcm.util.UrlSigner;
@ -35,6 +36,7 @@ import javax.ws.rs.Path;
import javax.ws.rs.PathParam; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam; import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.io.IOException; import java.io.IOException;
@ -64,37 +66,38 @@ public class AttachmentController {
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response allocateAttachment(@Auth Device device) throws RateLimitExceededException { public AttachmentDescriptor allocateAttachment(@Auth Account account)
rateLimiters.getAttachmentLimiter().validate(device.getNumber()); throws RateLimitExceededException
{
if (account.isRateLimited()) {
rateLimiters.getAttachmentLimiter().validate(account.getNumber());
}
long attachmentId = generateAttachmentId(); long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT); URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT);
AttachmentDescriptor descriptor = new AttachmentDescriptor(attachmentId, url.toExternalForm());
return Response.ok().entity(descriptor).build(); return new AttachmentDescriptor(attachmentId, url.toExternalForm());
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{attachmentId}") @Path("/{attachmentId}")
public Response redirectToAttachment(@Auth Device device, public AttachmentUri redirectToAttachment(@Auth Account account,
@PathParam("attachmentId") long attachmentId, @PathParam("attachmentId") long attachmentId,
@QueryParam("relay") String relay) @QueryParam("relay") Optional<String> relay)
throws IOException
{ {
try { try {
URL url; if (!relay.isPresent()) {
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET));
if (relay == null) url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET); } else {
else url = federatedClientManager.getClient(relay).getSignedAttachmentUri(attachmentId); return new AttachmentUri(federatedClientManager.getClient(relay.get()).getSignedAttachmentUri(attachmentId));
}
return Response.ok().entity(new AttachmentUri(url)).build();
} catch (IOException e) {
logger.warn("No conectivity", e);
return Response.status(500).build();
} catch (NoSuchPeerException e) { } catch (NoSuchPeerException e) {
logger.info("No such peer: " + relay); logger.info("No such peer: " + relay);
return Response.status(404).build(); throw new WebApplicationException(Response.status(404).build());
} }
} }

View File

@ -26,7 +26,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader; import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException; import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager; import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
@ -68,13 +70,13 @@ public class DeviceController {
@GET @GET
@Path("/provisioning_code") @Path("/provisioning_code")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Device device) public VerificationCode createDeviceToken(@Auth Account account)
throws RateLimitExceededException throws RateLimitExceededException
{ {
rateLimiters.getVerifyLimiter().validate(device.getNumber()); //TODO: New limiter? rateLimiters.getVerifyLimiter().validate(account.getNumber()); //TODO: New limiter?
VerificationCode verificationCode = generateVerificationCode(); VerificationCode verificationCode = generateVerificationCode();
pendingDevices.store(device.getNumber(), verificationCode.getVerificationCode()); pendingDevices.store(account.getNumber(), verificationCode.getVerificationCode());
return verificationCode; return verificationCode;
} }
@ -84,12 +86,11 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Path("/{verification_code}") @Path("/{verification_code}")
public long verifyDeviceToken(@PathParam("verification_code") String verificationCode, public DeviceResponse verifyDeviceToken(@PathParam("verification_code") String verificationCode,
@HeaderParam("Authorization") String authorizationHeader, @HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes) @Valid AccountAttributes accountAttributes)
throws RateLimitExceededException throws RateLimitExceededException
{ {
Device device;
try { try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader); AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber(); String number = header.getNumber();
@ -105,24 +106,28 @@ public class DeviceController {
throw new WebApplicationException(Response.status(403).build()); throw new WebApplicationException(Response.status(403).build());
} }
device = new Device(); Optional<Account> account = accounts.get(number);
device.setNumber(number);
if (!account.isPresent()) {
throw new WebApplicationException(Response.status(403).build());
}
Device device = new Device();
device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey()); device.setSignalingKey(accountAttributes.getSignalingKey());
device.setSupportsSms(accountAttributes.getSupportsSms());
device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
accounts.provisionDevice(device); account.get().addDevice(device);
accounts.update(account.get());
pendingDevices.remove(number); pendingDevices.remove(number);
logger.debug("Stored new device device..."); return new DeviceResponse(device.getId());
} catch (InvalidAuthorizationHeaderException e) { } catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad Authorization Header", e); logger.info("Bad Authorization Header", e);
throw new WebApplicationException(Response.status(401).build()); throw new WebApplicationException(Response.status(401).build());
} }
return device.getDeviceId();
} }
@VisibleForTesting protected VerificationCode generateVerificationCode() { @VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@ -21,11 +21,11 @@ import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed; import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContactTokens; import org.whispersystems.textsecuregcm.entities.ClientContactTokens;
import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
@ -60,10 +60,10 @@ public class DirectoryController {
@GET @GET
@Path("/{token}") @Path("/{token}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response getTokenPresence(@Auth Device device, @PathParam("token") String token) public Response getTokenPresence(@Auth Account account, @PathParam("token") String token)
throws RateLimitExceededException throws RateLimitExceededException
{ {
rateLimiters.getContactsLimiter().validate(device.getNumber()); rateLimiters.getContactsLimiter().validate(account.getNumber());
try { try {
Optional<ClientContact> contact = directory.get(Base64.decodeWithoutPadding(token)); Optional<ClientContact> contact = directory.get(Base64.decodeWithoutPadding(token));
@ -82,10 +82,10 @@ public class DirectoryController {
@Path("/tokens") @Path("/tokens")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public ClientContacts getContactIntersection(@Auth Device device, @Valid ClientContactTokens contacts) public ClientContacts getContactIntersection(@Auth Account account, @Valid ClientContactTokens contacts)
throws RateLimitExceededException throws RateLimitExceededException
{ {
rateLimiters.getContactsLimiter().validate(device.getNumber(), contacts.getContacts().size()); rateLimiters.getContactsLimiter().validate(account.getNumber(), contacts.getContacts().size());
try { try {
List<byte[]> tokens = new LinkedList<>(); List<byte[]> tokens = new LinkedList<>();

View File

@ -16,9 +16,7 @@
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.google.protobuf.InvalidProtocolBufferException;
import com.yammer.dropwizard.auth.Auth; import com.yammer.dropwizard.auth.Auth;
import com.yammer.metrics.annotation.Timed; import com.yammer.metrics.annotation.Timed;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -27,40 +25,25 @@ import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.PathParam; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException; import java.io.IOException;
import java.net.URL;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.google.common.base.Preconditions.checkState;
@Path("/v1/federation") @Path("/v1/federation")
public class FederationController { public class FederationController {
@ -69,16 +52,20 @@ public class FederationController {
private static final int ACCOUNT_CHUNK_SIZE = 10000; private static final int ACCOUNT_CHUNK_SIZE = 10000;
private final PushSender pushSender;
private final Keys keys;
private final AccountsManager accounts; private final AccountsManager accounts;
private final UrlSigner urlSigner; private final AttachmentController attachmentController;
private final KeysController keysController;
private final MessageController messageController;
public FederationController(Keys keys, AccountsManager accounts, PushSender pushSender, UrlSigner urlSigner) { public FederationController(AccountsManager accounts,
this.keys = keys; AttachmentController attachmentController,
KeysController keysController,
MessageController messageController)
{
this.accounts = accounts; this.accounts = accounts;
this.pushSender = pushSender; this.attachmentController = attachmentController;
this.urlSigner = urlSigner; this.keysController = keysController;
this.messageController = messageController;
} }
@Timed @Timed
@ -87,82 +74,61 @@ public class FederationController {
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer, public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer,
@PathParam("attachmentId") long attachmentId) @PathParam("attachmentId") long attachmentId)
throws IOException
{ {
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET); return attachmentController.redirectToAttachment(new NonLimitedAccount("Unknown", peer.getName()),
return new AttachmentUri(url); attachmentId, Optional.<String>absent());
} }
@Timed @Timed
@GET @GET
@Path("/key/{number}") @Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKey(@Auth FederatedPeer peer, public PreKey getKey(@Auth FederatedPeer peer,
@PathParam("number") String number) @PathParam("number") String number)
throws IOException
{ {
Optional<Account> account = accounts.getAccount(number); try {
UnstructuredPreKeyList keyList = null; return keysController.get(new NonLimitedAccount("Unknown", peer.getName()), number, Optional.<String>absent());
if (account.isPresent()) } catch (RateLimitExceededException e) {
keyList = keys.get(number, account.get()); logger.warn("Rate limiting on federated channel", e);
if (!account.isPresent() || keyList.getKeys().isEmpty()) throw new IOException(e);
throw new WebApplicationException(Response.status(404).build()); }
return keyList; }
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKeys(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysController.getDeviceKey(new NonLimitedAccount("Unknown", peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
} }
@Timed @Timed
@PUT @PUT
@Path("/message") @Path("/messages/{source}/{destination}")
@Consumes(MediaType.APPLICATION_JSON) public void sendMessages(@Auth FederatedPeer peer,
@Produces(MediaType.APPLICATION_JSON) @PathParam("source") String source,
public MessageResponse relayMessage(@Auth FederatedPeer peer, @Valid List<RelayMessage> messages) @PathParam("destination") String destination,
@Valid IncomingMessageList messages)
throws IOException throws IOException
{ {
try { try {
Map<String, Set<Long>> localDestinations = new HashMap<>(); messages.setRelay(null);
for (RelayMessage message : messages) { messageController.sendMessage(new NonLimitedAccount(source, peer.getName()), destination, messages);
Set<Long> deviceIds = localDestinations.get(message.getDestination()); } catch (RateLimitExceededException e) {
if (deviceIds == null) { logger.warn("Rate limiting on federated channel", e);
deviceIds = new HashSet<>(); throw new IOException(e);
localDestinations.put(message.getDestination(), deviceIds);
}
deviceIds.add(message.getDestinationDeviceId());
}
List<Account> localAccounts = null;
try {
localAccounts = accounts.getAccountsForDevices(localDestinations);
} catch (MissingDevicesException e) {
return new MessageResponse(e.missingNumbers);
}
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
for (RelayMessage message : messages) {
Account destinationAccount = null;
for (Account account : localAccounts)
if (account.getNumber().equals(message.getDestination()))
destinationAccount= account;
checkState(destinationAccount != null);
Device device = destinationAccount.getDevice(message.getDestinationDeviceId());
OutgoingMessageSignal signal = OutgoingMessageSignal.parseFrom(message.getOutgoingMessageSignal())
.toBuilder()
.setRelay(peer.getName())
.build();
try {
pushSender.sendMessage(device, signal);
success.add(device.getBackwardsCompatibleNumberEncoding());
} catch (NoSuchUserException e) {
logger.info("No such user", e);
failure.add(device.getBackwardsCompatibleNumberEncoding());
}
}
return new MessageResponse(success, failure);
} catch (InvalidProtocolBufferException ipe) {
logger.warn("ProtoBuf", ipe);
throw new WebApplicationException(Response.status(400).build());
} }
} }
@ -181,15 +147,16 @@ public class FederationController {
public ClientContacts getUserTokens(@Auth FederatedPeer peer, public ClientContacts getUserTokens(@Auth FederatedPeer peer,
@PathParam("offset") int offset) @PathParam("offset") int offset)
{ {
List<Device> numberList = accounts.getAllMasterDevices(offset, ACCOUNT_CHUNK_SIZE); List<Account> accountList = accounts.getAll(offset, ACCOUNT_CHUNK_SIZE);
List<ClientContact> clientContacts = new LinkedList<>(); List<ClientContact> clientContacts = new LinkedList<>();
for (Device device : numberList) { for (Account account : accountList) {
byte[] token = Util.getContactToken(device.getNumber()); byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms()); ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
if (!device.isActive()) if (!account.isActive()) {
clientContact.setInactive(true); clientContact.setInactive(true);
}
clientContacts.add(clientContact); clientContacts.add(clientContact);
} }

View File

@ -29,7 +29,6 @@ import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid; import javax.validation.Valid;
@ -43,7 +42,6 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.util.List;
@Path("/v1/keys") @Path("/v1/keys")
public class KeysController { public class KeysController {
@ -52,46 +50,47 @@ public class KeysController {
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final Keys keys; private final Keys keys;
private final AccountsManager accountsManager;
private final FederatedClientManager federatedClientManager; private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accountsManager, public KeysController(RateLimiters rateLimiters, Keys keys,
FederatedClientManager federatedClientManager) FederatedClientManager federatedClientManager)
{ {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.keys = keys; this.keys = keys;
this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager; this.federatedClientManager = federatedClientManager;
} }
@Timed @Timed
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Device device, @Valid PreKeyList preKeys) { public void setKeys(@Auth Account account, @Valid PreKeyList preKeys) {
keys.store(device.getNumber(), device.getDeviceId(), preKeys.getLastResortKey(), preKeys.getKeys()); Device device = account.getAuthenticatedDevice().get();
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
} }
private List<PreKey> getKeys(Device device, String number, String relay) throws RateLimitExceededException @Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{ {
rateLimiters.getPreKeysLimiter().validate(device.getNumber() + "__" + number);
try { try {
UnstructuredPreKeyList keyList; if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
if (relay == null) {
Optional<Account> account = accountsManager.getAccount(number);
if (account.isPresent())
keyList = keys.get(number, account.get());
else
throw new WebApplicationException(Response.status(404).build());
} else {
keyList = federatedClientManager.getClient(relay).getKeys(number);
} }
if (keyList == null || keyList.getKeys().isEmpty()) throw new WebApplicationException(Response.status(404).build()); Optional<UnstructuredPreKeyList> results;
else return keyList.getKeys();
if (!relay.isPresent()) results = getLocalKeys(number, deviceId);
else results = federatedClientManager.getClient(relay.get()).getKeys(number, deviceId);
if (results.isPresent()) return results.get();
else throw new WebApplicationException(Response.status(404).build());
} catch (NoSuchPeerException e) { } catch (NoSuchPeerException e) {
logger.info("No peer: " + relay);
throw new WebApplicationException(Response.status(404).build()); throw new WebApplicationException(Response.status(404).build());
} }
} }
@ -100,15 +99,27 @@ public class KeysController {
@GET @GET
@Path("/{number}") @Path("/{number}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response get(@Auth Device device, public PreKey get(@Auth Account account,
@PathParam("number") String number, @PathParam("number") String number,
@QueryParam("multikeys") Optional<String> multikey, @QueryParam("relay") Optional<String> relay)
@QueryParam("relay") String relay)
throws RateLimitExceededException throws RateLimitExceededException
{ {
if (!multikey.isPresent()) UnstructuredPreKeyList results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
return Response.ok(getKeys(device, number, relay).get(0)).type(MediaType.APPLICATION_JSON).build(); return results.getKeys().get(0);
else }
return Response.ok(getKeys(device, number, relay)).type(MediaType.APPLICATION_JSON).build();
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceId) {
try {
if (deviceId.equals("*")) {
return keys.get(number);
}
Optional<PreKey> targetKey = keys.get(number, Long.parseLong(deviceId));
if (targetKey.isPresent()) return Optional.of(new UnstructuredPreKeyList(targetKey.get()));
else return Optional.absent();
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
} }
} }

View File

@ -16,383 +16,226 @@
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.yammer.dropwizard.auth.AuthenticationException; import com.yammer.dropwizard.auth.Auth;
import com.yammer.dropwizard.auth.basic.BasicCredentials; import com.yammer.metrics.annotation.Timed;
import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter;
import com.yammer.metrics.core.Timer;
import com.yammer.metrics.core.TimerContext;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.entities.MissingDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient; import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException; import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable; import javax.validation.Valid;
import javax.servlet.AsyncContext; import javax.ws.rs.Consumes;
import javax.servlet.http.HttpServlet; import javax.ws.rs.POST;
import javax.servlet.http.HttpServletRequest; import javax.ws.rs.PUT;
import javax.servlet.http.HttpServletResponse; import javax.ws.rs.Path;
import java.io.BufferedReader; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public class MessageController extends HttpServlet { @Path("/v1/messages")
public class MessageController {
public static final String PATH = "/v1/messages/";
private final Meter successMeter = Metrics.newMeter(MessageController.class, "deliver_message", "success", TimeUnit.MINUTES);
private final Meter failureMeter = Metrics.newMeter(MessageController.class, "deliver_message", "failure", TimeUnit.MINUTES);
private final Timer timer = Metrics.newTimer(MessageController.class, "deliver_message_time", TimeUnit.MILLISECONDS, TimeUnit.MINUTES);
private final Logger logger = LoggerFactory.getLogger(MessageController.class); private final Logger logger = LoggerFactory.getLogger(MessageController.class);
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final DeviceAuthenticator deviceAuthenticator;
private final PushSender pushSender; private final PushSender pushSender;
private final FederatedClientManager federatedClientManager; private final FederatedClientManager federatedClientManager;
private final ObjectMapper objectMapper;
private final ExecutorService executor;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
public MessageController(RateLimiters rateLimiters, public MessageController(RateLimiters rateLimiters,
DeviceAuthenticator deviceAuthenticator,
PushSender pushSender, PushSender pushSender,
AccountsManager accountsManager, AccountsManager accountsManager,
FederatedClientManager federatedClientManager) FederatedClientManager federatedClientManager)
{ {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.deviceAuthenticator = deviceAuthenticator;
this.pushSender = pushSender; this.pushSender = pushSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.federatedClientManager = federatedClientManager; this.federatedClientManager = federatedClientManager;
this.objectMapper = new ObjectMapper();
this.executor = Executors.newFixedThreadPool(10);
} }
class LocalOrRemoteDevice { @Timed
Device device; @Path("/{destination}")
String relay, number; long deviceId; @PUT
LocalOrRemoteDevice(Device device) { @Consumes(MediaType.APPLICATION_JSON)
this.device = device; this.number = device.getNumber(); this.deviceId = device.getDeviceId(); public void sendMessage(@Auth Account source,
} @PathParam("destination") String destinationName,
LocalOrRemoteDevice(String relay, String number, long deviceId) { @Valid IncomingMessageList messages)
this.relay = relay; this.number = number; this.deviceId = deviceId; throws IOException, RateLimitExceededException
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) {
TimerContext timerContext = timer.time();
try {
Device sender = authenticate(req);
rateLimiters.getMessagesLimiter().validate(sender.getNumber());
handleAsyncDelivery(timerContext, req.startAsync(), sender, parseIncomingMessages(req));
} catch (AuthenticationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(401);
} catch (ValidationException e) {
failureMeter.mark();
timerContext.stop();
resp.setStatus(415);
} catch (IOException e) {
logger.warn("IOE", e);
failureMeter.mark();
timerContext.stop();
resp.setStatus(501);
} catch (RateLimitExceededException e) {
timerContext.stop();
failureMeter.mark();
resp.setStatus(413);
}
}
private void handleAsyncDelivery(final TimerContext timerContext,
final AsyncContext context,
final Device sender,
final IncomingMessageList messages)
{ {
executor.submit(new Runnable() { rateLimiters.getMessagesLimiter().validate(source.getNumber());
@Override
public void run() {
List<String> success = new LinkedList<>();
List<String> failure = new LinkedList<>();
HttpServletResponse response = (HttpServletResponse) context.getResponse();
try { try {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages; if (messages.getRelay() != null) sendLocalMessage(source, destinationName, messages);
try { else sendRelayMessage(source, destinationName, messages);
outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), messages.getMessages());
} catch (MissingDevicesException e) {
byte[] responseData = serializeResponse(new MessageResponse(e.missingNumbers));
response.setContentLength(responseData.length);
response.getOutputStream().write(responseData);
context.complete();
failureMeter.mark();
timerContext.stop();
return;
}
Map<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> relayMessages = new HashMap<>();
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : outgoingMessages) {
String relay = messagePair.first().relay;
if (Util.isEmpty(relay)) {
String encodedId = messagePair.first().device.getBackwardsCompatibleNumberEncoding();
try {
pushSender.sendMessage(messagePair.first().device, messagePair.second());
success.add(encodedId);
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
logger.debug("No such user", e); throw new WebApplicationException(Response.status(404).build());
failure.add(encodedId); } catch (MissingDevicesException e) {
} throw new WebApplicationException(Response.status(409)
} else { .entity(new MissingDevices(e.getMissingDevices()))
Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> messageSet = relayMessages.get(relay); .build());
if (messageSet == null) {
messageSet = new HashSet<>();
relayMessages.put(relay, messageSet);
}
messageSet.add(messagePair);
} }
} }
for (Map.Entry<String, Set<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>>> messagesForRelay : relayMessages.entrySet()) { @Timed
@Path("/")
@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public MessageResponse sendMessageLegacy(@Auth Account source, @Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
try { try {
FederatedClient client = federatedClientManager.getClient(messagesForRelay.getKey()); List<IncomingMessage> incomingMessages = messages.getMessages();
validateLegacyDestinations(incomingMessages);
List<RelayMessage> messages = new LinkedList<>(); messages.setRelay(incomingMessages.get(0).getRelay());
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> message : messagesForRelay.getValue()) { sendMessage(source, incomingMessages.get(0).getDestination(), messages);
messages.add(new RelayMessage(message.first().number,
message.first().deviceId, return new MessageResponse(new LinkedList<String>(), new LinkedList<String>());
message.second().toByteArray())); } catch (ValidationException e) {
throw new WebApplicationException(Response.status(422).build());
}
} }
MessageResponse relayResponse = client.sendMessages(messages); private void sendLocalMessage(Account source,
for (String string : relayResponse.getSuccess()) String destinationName,
success.add(string); IncomingMessageList messages)
for (String string : relayResponse.getFailure()) throws NoSuchUserException, MissingDevicesException, IOException
failure.add(string); {
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) {
sendLocalMessage(source, destination, destinationDevice.get(), incomingMessage);
}
}
}
private void sendLocalMessage(Account source,
Account destinationAccount,
Device destinationDevice,
IncomingMessage incomingMessage)
throws NoSuchUserException, IOException
{
try {
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
OutgoingMessageSignal.Builder messageBuilder = OutgoingMessageSignal.newBuilder();
messageBuilder.setType(incomingMessage.getType())
.setSource(source.getNumber())
.setTimestamp(System.currentTimeMillis());
if (messageBody.isPresent()) {
messageBuilder.setMessage(ByteString.copyFrom(messageBody.get()));
}
if (source.getRelay().isPresent()) {
messageBuilder.setRelay(source.getRelay().get());
}
pushSender.sendMessage(destinationAccount, destinationDevice, messageBuilder.build());
} catch (NotPushRegisteredException e) {
if (destinationDevice.isMaster()) throw new NoSuchUserException(e);
else logger.debug("Not registered", e);
} catch (TransientPushFailureException e) {
if (destinationDevice.isMaster()) throw new IOException(e);
else logger.debug("Transient failure", e);
}
}
private void sendRelayMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws IOException, NoSuchUserException
{
try {
FederatedClient client = federatedClientManager.getClient(messages.getRelay());
client.sendMessages(source.getNumber(), destinationName, messages);
} catch (NoSuchPeerException e) { } catch (NoSuchPeerException e) {
logger.info("No such peer", e); throw new NoSuchUserException(e);
for (Pair<LocalOrRemoteDevice, OutgoingMessageSignal> messagePair : messagesForRelay.getValue())
failure.add(messagePair.first().number);
} }
} }
byte[] responseData = serializeResponse(new MessageResponse(success, failure)); private Account getDestinationAccount(String destination)
response.setContentLength(responseData.length); throws NoSuchUserException
response.getOutputStream().write(responseData); {
context.complete(); Optional<Account> account = accountsManager.get(destination);
successMeter.mark();
} catch (IOException e) { if (!account.isPresent() || !account.get().isActive()) {
logger.warn("Async Handler", e); throw new NoSuchUserException(destination);
failureMeter.mark();
response.setStatus(501);
context.complete();
} catch (Exception e) {
logger.error("Unknown error sending message", e);
failureMeter.mark();
response.setStatus(500);
context.complete();
} }
timerContext.stop(); return account.get();
}
});
} }
@Nullable private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
private List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> getOutgoingMessageSignals(String sourceNumber,
List<IncomingMessage> incomingMessages)
throws MissingDevicesException throws MissingDevicesException
{ {
List<Pair<LocalOrRemoteDevice, OutgoingMessageSignal>> outgoingMessages = new LinkedList<>(); Set<Long> destinationDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Account> localAccounts = accountsManager.getAccountsForDevices(getLocalDestinations(incomingMessages)); for (IncomingMessage message : messages) {
destinationDeviceIds.add(message.getDestinationDeviceId());
Set<String> destinationNumbers = new HashSet<>();
for (IncomingMessage incoming : incomingMessages)
destinationNumbers.add(incoming.getDestination());
for (IncomingMessage incoming : incomingMessages) {
OutgoingMessageSignal.Builder outgoingMessage = OutgoingMessageSignal.newBuilder();
outgoingMessage.setType(incoming.getType());
outgoingMessage.setSource(sourceNumber);
byte[] messageBody = getMessageBody(incoming);
if (messageBody != null) {
outgoingMessage.setMessage(ByteString.copyFrom(messageBody));
} }
outgoingMessage.setTimestamp(System.currentTimeMillis()); for (Device device : account.getDevices()) {
if (!destinationDeviceIds.contains(device.getId())) {
for (String destination : destinationNumbers) { missingDeviceIds.add(device.getId());
if (!destination.equals(incoming.getDestination()))
outgoingMessage.addDestinations(destination);
}
LocalOrRemoteDevice device = null;
if (!Util.isEmpty(incoming.getRelay()))
device = new LocalOrRemoteDevice(incoming.getRelay(), incoming.getDestination(), incoming.getDestinationDeviceId());
else {
Account destination = null;
for (Account account : localAccounts) {
if (account.getNumber().equals(incoming.getDestination())) {
destination = account;
break;
} }
} }
if (destination != null) if (!missingDeviceIds.isEmpty()) {
device = new LocalOrRemoteDevice(destination.getDevice(incoming.getDestinationDeviceId())); throw new MissingDevicesException(missingDeviceIds);
}
if (device != null)
outgoingMessages.add(new Pair<>(device, outgoingMessage.build()));
}
return outgoingMessages;
}
// We use a map from number -> deviceIds here (instead of passing the list of messages to accountsManager) so that
// we can share as much code as possible with FederationController (which has RelayMessages, not IncomingMessages)
private Map<String, Set<Long>> getLocalDestinations(List<IncomingMessage> incomingMessages) {
Map<String, Set<Long>> localDestinations = new HashMap<>();
for (IncomingMessage incoming : incomingMessages) {
if (!Util.isEmpty(incoming.getRelay()))
continue;
Set<Long> deviceIds = localDestinations.get(incoming.getDestination());
if (deviceIds == null) {
deviceIds = new HashSet<>();
localDestinations.put(incoming.getDestination(), deviceIds);
}
deviceIds.add(incoming.getDestinationDeviceId());
}
return localDestinations;
}
private byte[] getMessageBody(IncomingMessage message) {
try {
return Base64.decode(message.getBody());
} catch (IOException ioe) {
ioe.printStackTrace();
return null;
} }
} }
private byte[] serializeResponse(MessageResponse response) throws IOException { private void validateLegacyDestinations(List<IncomingMessage> messages)
try { throws ValidationException
return objectMapper.writeValueAsBytes(response);
} catch (JsonProcessingException e) {
throw new IOException(e);
}
}
private IncomingMessageList parseIncomingMessages(HttpServletRequest request)
throws IOException, ValidationException
{ {
BufferedReader reader = request.getReader(); String destination = null;
StringBuilder content = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) { for (IncomingMessage message : messages) {
content.append(line); if (destination != null && !destination.equals(message.getDestination())) {
throw new ValidationException("Multiple account destinations!");
} }
IncomingMessageList messages = objectMapper.readValue(content.toString(), destination = message.getDestination();
IncomingMessageList.class); }
if (messages.getMessages() == null) {
throw new ValidationException();
} }
for (IncomingMessage message : messages.getMessages()) { private Optional<byte[]> getMessageBody(IncomingMessage message) {
if (message.getBody() == null) throw new ValidationException();
if (message.getDestination() == null) throw new ValidationException();
}
return messages;
}
private Device authenticate(HttpServletRequest request) throws AuthenticationException {
try { try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromFullHeader(request.getHeader("Authorization")); return Optional.of(Base64.decode(message.getBody()));
BasicCredentials credentials = new BasicCredentials(authorizationHeader.getNumber() + "." + authorizationHeader.getDeviceId(), } catch (IOException ioe) {
authorizationHeader.getPassword() ); logger.debug("Bad B64", ioe);
return Optional.absent();
Optional<Device> account = deviceAuthenticator.authenticate(credentials);
if (account.isPresent()) return account.get();
else throw new AuthenticationException("Bad credentials");
} catch (InvalidAuthorizationHeaderException e) {
throw new AuthenticationException(e);
} }
} }
// @Timed
// @POST
// @Consumes(MediaType.APPLICATION_JSON)
// @Produces(MediaType.APPLICATION_JSON)
// public MessageResponse sendMessage(@Auth Device sender, IncomingMessageList messages)
// throws IOException
// {
// List<String> success = new LinkedList<>();
// List<String> failure = new LinkedList<>();
// List<IncomingMessage> incomingMessages = messages.getMessages();
// List<OutgoingMessageSignal> outgoingMessages = getOutgoingMessageSignals(sender.getNumber(), incomingMessages);
//
// IterablePair<IncomingMessage, OutgoingMessageSignal> listPair = new IterablePair<>(incomingMessages, outgoingMessages);
//
// for (Pair<IncomingMessage, OutgoingMessageSignal> messagePair : listPair) {
// String destination = messagePair.first().getDestination();
// String relay = messagePair.first().getRelay();
//
// try {
// if (Util.isEmpty(relay)) sendLocalMessage(destination, messagePair.second());
// else sendRelayMessage(relay, destination, messagePair.second());
// success.add(destination);
// } catch (NoSuchUserException e) {
// logger.debug("No such user", e);
// failure.add(destination);
// }
// }
//
// return new MessageResponse(success, failure);
// }
} }

View File

@ -4,8 +4,13 @@ import java.util.List;
import java.util.Set; import java.util.Set;
public class MissingDevicesException extends Exception { public class MissingDevicesException extends Exception {
public Set<String> missingNumbers; private final List<Long> missingDevices;
public MissingDevicesException(Set<String> missingNumbers) {
this.missingNumbers = missingNumbers; public MissingDevicesException(List<Long> missingDevices) {
this.missingDevices = missingDevices;
}
public List<Long> getMissingDevices() {
return missingDevices;
} }
} }

View File

@ -18,4 +18,7 @@ package org.whispersystems.textsecuregcm.controllers;
public class ValidationException extends Exception { public class ValidationException extends Exception {
public ValidationException(String s) {
super(s);
}
} }

View File

@ -0,0 +1,13 @@
package org.whispersystems.textsecuregcm.entities;
public class CryptoEncodingException extends Exception {
public CryptoEncodingException(String s) {
super(s);
}
public CryptoEncodingException(Exception e) {
super(e);
}
}

View File

@ -0,0 +1,21 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
public class DeviceResponse {
@JsonProperty
private long deviceId;
@VisibleForTesting
public DeviceResponse() {}
public DeviceResponse(long deviceId) {
this.deviceId = deviceId;
}
public long getDeviceId() {
return deviceId;
}
}

View File

@ -51,7 +51,7 @@ public class EncryptedOutgoingMessage {
this.signalingKey = signalingKey; this.signalingKey = signalingKey;
} }
public String serialize() throws IOException { public String serialize() throws CryptoEncodingException {
byte[] plaintext = outgoingMessage.toByteArray(); byte[] plaintext = outgoingMessage.toByteArray();
SecretKeySpec cipherKey = getCipherKey (signalingKey); SecretKeySpec cipherKey = getCipherKey (signalingKey);
SecretKeySpec macKey = getMacKey(signalingKey); SecretKeySpec macKey = getMacKey(signalingKey);
@ -61,7 +61,7 @@ public class EncryptedOutgoingMessage {
} }
private byte[] getCiphertext(byte[] plaintext, SecretKeySpec cipherKey, SecretKeySpec macKey) private byte[] getCiphertext(byte[] plaintext, SecretKeySpec cipherKey, SecretKeySpec macKey)
throws IOException throws CryptoEncodingException
{ {
try { try {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
@ -85,31 +85,39 @@ public class EncryptedOutgoingMessage {
throw new AssertionError(e); throw new AssertionError(e);
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
logger.warn("Invalid Key", e); logger.warn("Invalid Key", e);
throw new IOException("Invalid key!"); throw new CryptoEncodingException("Invalid key!");
} }
} }
private SecretKeySpec getCipherKey(String signalingKey) throws IOException { private SecretKeySpec getCipherKey(String signalingKey) throws CryptoEncodingException {
try {
byte[] signalingKeyBytes = Base64.decode(signalingKey); byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] cipherKey = new byte[CIPHER_KEY_SIZE]; byte[] cipherKey = new byte[CIPHER_KEY_SIZE];
if (signalingKeyBytes.length < CIPHER_KEY_SIZE) if (signalingKeyBytes.length < CIPHER_KEY_SIZE)
throw new IOException("Signaling key too short!"); throw new CryptoEncodingException("Signaling key too short!");
System.arraycopy(signalingKeyBytes, 0, cipherKey, 0, cipherKey.length); System.arraycopy(signalingKeyBytes, 0, cipherKey, 0, cipherKey.length);
return new SecretKeySpec(cipherKey, "AES"); return new SecretKeySpec(cipherKey, "AES");
} catch (IOException e) {
throw new CryptoEncodingException(e);
}
} }
private SecretKeySpec getMacKey(String signalingKey) throws IOException { private SecretKeySpec getMacKey(String signalingKey) throws CryptoEncodingException {
try {
byte[] signalingKeyBytes = Base64.decode(signalingKey); byte[] signalingKeyBytes = Base64.decode(signalingKey);
byte[] macKey = new byte[MAC_KEY_SIZE]; byte[] macKey = new byte[MAC_KEY_SIZE];
if (signalingKeyBytes.length < CIPHER_KEY_SIZE + MAC_KEY_SIZE) if (signalingKeyBytes.length < CIPHER_KEY_SIZE + MAC_KEY_SIZE)
throw new IOException(("Signaling key too short!")); throw new CryptoEncodingException("Signaling key too short!");
System.arraycopy(signalingKeyBytes, CIPHER_KEY_SIZE, macKey, 0, macKey.length); System.arraycopy(signalingKeyBytes, CIPHER_KEY_SIZE, macKey, 0, macKey.length);
return new SecretKeySpec(macKey, "HmacSHA256"); return new SecretKeySpec(macKey, "HmacSHA256");
} catch (IOException e) {
throw new CryptoEncodingException(e);
}
} }
} }

View File

@ -29,9 +29,20 @@ public class IncomingMessageList {
@Valid @Valid
private List<IncomingMessage> messages; private List<IncomingMessage> messages;
@JsonProperty
private String relay;
public IncomingMessageList() {} public IncomingMessageList() {}
public List<IncomingMessage> getMessages() { public List<IncomingMessage> getMessages() {
return messages; return messages;
} }
public String getRelay() {
return relay;
}
public void setRelay(String relay) {
this.relay = relay;
}
} }

View File

@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
public class MissingDevices {
@JsonProperty
public List<Long> missingDevices;
public MissingDevices(List<Long> missingDevices) {
this.missingDevices = missingDevices;
}
}

View File

@ -35,7 +35,6 @@ public class PreKey {
private String number; private String number;
@JsonProperty @JsonProperty
@NotNull
private long deviceId; private long deviceId;
@JsonProperty @JsonProperty

View File

@ -23,14 +23,24 @@ import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList;
import java.util.List; import java.util.List;
public class UnstructuredPreKeyList { public class UnstructuredPreKeyList {
@JsonProperty @JsonProperty
@NotNull @NotNull
@Valid @Valid
private List<PreKey> keys; private List<PreKey> keys;
@VisibleForTesting
public UnstructuredPreKeyList() {}
public UnstructuredPreKeyList(PreKey preKey) {
this.keys = new LinkedList<PreKey>();
this.keys.add(preKey);
}
public UnstructuredPreKeyList(List<PreKey> preKeys) { public UnstructuredPreKeyList(List<PreKey> preKeys) {
this.keys = preKeys; this.keys = preKeys;
} }
@ -39,7 +49,8 @@ public class UnstructuredPreKeyList {
return keys; return keys;
} }
@VisibleForTesting public boolean equals(Object o) { @VisibleForTesting
public boolean equals(Object o) {
if (!(o instanceof UnstructuredPreKeyList) || if (!(o instanceof UnstructuredPreKeyList) ||
((UnstructuredPreKeyList) o).keys.size() != keys.size()) ((UnstructuredPreKeyList) o).keys.size() != keys.size())
return false; return false;

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.federation; package org.whispersystems.textsecuregcm.federation;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.Client; import com.sun.jersey.api.client.Client;
import com.sun.jersey.api.client.ClientHandlerException; import com.sun.jersey.api.client.ClientHandlerException;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
@ -34,14 +35,19 @@ import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageResponse; import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage; import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactory;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
@ -54,6 +60,7 @@ import java.security.SecureRandom;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.List; import java.util.List;
import java.util.Map;
public class FederatedClient { public class FederatedClient {
@ -61,8 +68,9 @@ public class FederatedClient {
private static final String USER_COUNT_PATH = "/v1/federation/user_count"; private static final String USER_COUNT_PATH = "/v1/federation/user_count";
private static final String USER_TOKENS_PATH = "/v1/federation/user_tokens/%d"; private static final String USER_TOKENS_PATH = "/v1/federation/user_tokens/%d";
private static final String RELAY_MESSAGE_PATH = "/v1/federation/message"; private static final String RELAY_MESSAGE_PATH = "/v1/federation/messages/%s/%s";
private static final String PREKEY_PATH = "/v1/federation/key/%s"; private static final String PREKEY_PATH = "/v1/federation/key/%s";
private static final String PREKEY_PATH_DEVICE = "/v1/federation/key/%s/%s";
private static final String ATTACHMENT_URI_PATH = "/v1/federation/attachment/%d"; private static final String ATTACHMENT_URI_PATH = "/v1/federation/attachment/%d";
private final FederatedPeer peer; private final FederatedPeer peer;
@ -98,15 +106,27 @@ public class FederatedClient {
} }
} }
public UnstructuredPreKeyList getKeys(String destination) { public Optional<PreKey> getKey(String destination) {
try { try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH, destination)); WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH, destination));
return resource.accept(MediaType.APPLICATION_JSON) return Optional.of(resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader) .header("Authorization", authorizationHeader)
.get(UnstructuredPreKeyList.class); .get(PreKey.class));
} catch (UniformInterfaceException | ClientHandlerException e) { } catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e); logger.warn("PreKey", e);
return null; return Optional.absent();
}
}
public Optional<UnstructuredPreKeyList> getKeys(String destination, String device) {
try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH_DEVICE, destination, device));
return Optional.of(resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.get(UnstructuredPreKeyList.class));
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e);
return Optional.absent();
} }
} }
@ -138,21 +158,19 @@ public class FederatedClient {
} }
} }
public MessageResponse sendMessages(List<RelayMessage> messages) public void sendMessages(String source, String destination, IncomingMessageList messages)
throws IOException throws IOException
{ {
try { try {
WebResource resource = client.resource(peer.getUrl()).path(RELAY_MESSAGE_PATH); WebResource resource = client.resource(peer.getUrl()).path(String.format(RELAY_MESSAGE_PATH, source, destination));
ClientResponse response = resource.type(MediaType.APPLICATION_JSON) ClientResponse response = resource.type(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader) .header("Authorization", authorizationHeader)
.entity(messages) .entity(messages)
.put(ClientResponse.class); .put(ClientResponse.class);
if (response.getStatus() != 200 && response.getStatus() != 204) { if (response.getStatus() != 200 && response.getStatus() != 204) {
throw new IOException("Bad response: " + response.getStatus()); throw new WebApplicationException(clientResponseToResponse(response));
} }
return response.getEntity(MessageResponse.class);
} catch (UniformInterfaceException | ClientHandlerException e) { } catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("sendMessage", e); logger.warn("sendMessage", e);
throw new IOException(e); throw new IOException(e);
@ -203,6 +221,19 @@ public class FederatedClient {
} }
} }
private Response clientResponseToResponse(ClientResponse r) {
Response.ResponseBuilder rb = Response.status(r.getStatus());
for (Map.Entry<String, List<String>> entry : r.getHeaders().entrySet()) {
for (String value : entry.getValue()) {
rb.header(entry.getKey(), value);
}
}
rb.entity(r.getEntityInputStream());
return rb.build();
}
public String getPeerName() { public String getPeerName() {
return peer.getName(); return peer.getName();
} }

View File

@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.federation;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.storage.Account;
public class NonLimitedAccount extends Account {
@JsonIgnore
private final String number;
@JsonIgnore
private final String relay;
public NonLimitedAccount(String number, String relay) {
this.number = number;
this.relay = relay;
}
public String getNumber() {
return number;
}
public boolean isRateLimited() {
return false;
}
public Optional<String> getRelay() {
return Optional.of(relay);
}
}

View File

@ -59,6 +59,7 @@ public class RateLimiters {
this.messagesLimiter = new RateLimiter(memcachedClient, "messages", this.messagesLimiter = new RateLimiter(memcachedClient, "messages",
config.getMessages().getBucketSize(), config.getMessages().getBucketSize(),
config.getMessages().getLeakRatePerMinute()); config.getMessages().getLeakRatePerMinute());
} }
public RateLimiter getMessagesLimiter() { public RateLimiter getMessagesLimiter() {

View File

@ -25,6 +25,7 @@ import com.yammer.metrics.core.Meter;
import org.bouncycastle.openssl.PEMReader; import org.bouncycastle.openssl.PEMReader;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -32,7 +33,6 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.net.MalformedURLException;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.KeyStoreException; import java.security.KeyStoreException;
@ -66,12 +66,12 @@ public class APNSender {
} }
public void sendMessage(String registrationId, EncryptedOutgoingMessage message) public void sendMessage(String registrationId, EncryptedOutgoingMessage message)
throws IOException throws TransientPushFailureException, NotPushRegisteredException
{ {
try { try {
if (!apnService.isPresent()) { if (!apnService.isPresent()) {
failure.mark(); failure.mark();
throw new IOException("APN access not configured!"); throw new TransientPushFailureException("APN access not configured!");
} }
String payload = APNS.newPayload() String payload = APNS.newPayload()
@ -83,12 +83,12 @@ public class APNSender {
apnService.get().push(registrationId, payload); apnService.get().push(registrationId, payload);
success.mark(); success.mark();
} catch (MalformedURLException mue) {
throw new AssertionError(mue);
} catch (NetworkIOException nioe) { } catch (NetworkIOException nioe) {
logger.warn("Network Error", nioe); logger.warn("Network Error", nioe);
failure.mark(); failure.mark();
throw new IOException("Error sending APN"); throw new TransientPushFailureException(nioe);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
} }
} }

View File

@ -22,7 +22,7 @@ import com.google.android.gcm.server.Result;
import com.google.android.gcm.server.Sender; import com.google.android.gcm.server.Sender;
import com.yammer.metrics.Metrics; import com.yammer.metrics.Metrics;
import com.yammer.metrics.core.Meter; import com.yammer.metrics.core.Meter;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import java.io.IOException; import java.io.IOException;
@ -40,8 +40,9 @@ public class GCMSender {
} }
public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage) public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage)
throws IOException, NoSuchUserException throws NotPushRegisteredException, TransientPushFailureException
{ {
try {
Message gcmMessage = new Message.Builder().addData("type", "message") Message gcmMessage = new Message.Builder().addData("type", "message")
.addData("message", outgoingMessage.serialize()) .addData("message", outgoingMessage.serialize())
.build(); .build();
@ -54,10 +55,15 @@ public class GCMSender {
} else { } else {
failure.mark(); failure.mark();
if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) { if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) {
throw new NoSuchUserException("User no longer registered with GCM."); throw new NotPushRegisteredException("Device no longer registered with GCM.");
} else { } else {
throw new IOException("GCM Failed: " + result.getErrorCodeName()); throw new TransientPushFailureException("GCM Failed: " + result.getErrorCodeName());
} }
} }
} catch (IOException e) {
throw new TransientPushFailureException(e);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
} }
} }

View File

@ -0,0 +1,11 @@
package org.whispersystems.textsecuregcm.push;
public class NotPushRegisteredException extends Exception {
public NotPushRegisteredException(String s) {
super(s);
}
public NotPushRegisteredException(Exception e) {
super(e);
}
}

View File

@ -16,35 +16,28 @@
*/ */
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import com.google.common.base.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration; import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager; import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.util.Pair;
import java.io.IOException; import java.io.IOException;
import java.security.KeyStoreException; import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class PushSender { public class PushSender {
private final Logger logger = LoggerFactory.getLogger(PushSender.class); private final Logger logger = LoggerFactory.getLogger(PushSender.class);
private final AccountsManager accounts; private final AccountsManager accounts;
private final GCMSender gcmSender; private final GCMSender gcmSender;
private final APNSender apnSender; private final APNSender apnSender;
private final StoredMessageManager storedMessageManager; private final StoredMessageManager storedMessageManager;
@ -52,56 +45,65 @@ public class PushSender {
public PushSender(GcmConfiguration gcmConfiguration, public PushSender(GcmConfiguration gcmConfiguration,
ApnConfiguration apnConfiguration, ApnConfiguration apnConfiguration,
StoredMessageManager storedMessageManager, StoredMessageManager storedMessageManager,
AccountsManager accounts, AccountsManager accounts)
DirectoryManager directory)
throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{ {
this.accounts = accounts; this.accounts = accounts;
this.storedMessageManager = storedMessageManager; this.storedMessageManager = storedMessageManager;
this.gcmSender = new GCMSender(gcmConfiguration.getApiKey()); this.gcmSender = new GCMSender(gcmConfiguration.getApiKey());
this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey()); this.apnSender = new APNSender(apnConfiguration.getCertificate(), apnConfiguration.getKey());
} }
public void sendMessage(Device device, MessageProtos.OutgoingMessageSignal outgoingMessage) public void sendMessage(Account account, Device device, MessageProtos.OutgoingMessageSignal outgoingMessage)
throws IOException, NoSuchUserException throws NotPushRegisteredException, TransientPushFailureException
{ {
String signalingKey = device.getSignalingKey(); String signalingKey = device.getSignalingKey();
EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey); EncryptedOutgoingMessage message = new EncryptedOutgoingMessage(outgoingMessage, signalingKey);
if (device.getGcmRegistrationId() != null) sendGcmMessage(device, message); if (device.getGcmId() != null) sendGcmMessage(account, device, message);
else if (device.getApnRegistrationId() != null) sendApnMessage(device, message); else if (device.getApnId() != null) sendApnMessage(account, device, message);
else if (device.getFetchesMessages()) storeFetchedMessage(device, message); else if (device.getFetchesMessages()) storeFetchedMessage(device, message);
else throw new NoSuchUserException("No push identifier!"); else throw new NotPushRegisteredException("No delivery possible!");
} }
private void sendGcmMessage(Device device, EncryptedOutgoingMessage outgoingMessage) private void sendGcmMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws IOException, NoSuchUserException throws NotPushRegisteredException, TransientPushFailureException
{ {
try { try {
String canonicalId = gcmSender.sendMessage(device.getGcmRegistrationId(), String canonicalId = gcmSender.sendMessage(device.getGcmId(), outgoingMessage);
outgoingMessage);
if (canonicalId != null) { if (canonicalId != null) {
device.setGcmRegistrationId(canonicalId); device.setGcmId(canonicalId);
accounts.update(device); accounts.update(account);
} }
} catch (NoSuchUserException e) { } catch (NotPushRegisteredException e) {
logger.debug("No Such User", e); logger.debug("No Such User", e);
device.setGcmRegistrationId(null); device.setGcmId(null);
accounts.update(device); accounts.update(account);
throw new NoSuchUserException("User no longer exists in GCM."); throw new NotPushRegisteredException(e);
} }
} }
private void sendApnMessage(Device device, EncryptedOutgoingMessage outgoingMessage) private void sendApnMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws IOException throws TransientPushFailureException, NotPushRegisteredException
{ {
apnSender.sendMessage(device.getApnRegistrationId(), outgoingMessage); try {
apnSender.sendMessage(device.getApnId(), outgoingMessage);
} catch (NotPushRegisteredException e) {
device.setApnId(null);
accounts.update(account);
throw new NotPushRegisteredException(e);
}
} }
private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage) throws IOException { private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws NotPushRegisteredException
{
try {
storedMessageManager.storeMessage(device, outgoingMessage); storedMessageManager.storeMessage(device, outgoingMessage);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
} }
} }

View File

@ -0,0 +1,11 @@
package org.whispersystems.textsecuregcm.push;
public class TransientPushFailureException extends Exception {
public TransientPushFailureException(String s) {
super(s);
}
public TransientPushFailureException(Exception e) {
super(e);
}
}

View File

@ -17,35 +17,43 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import com.fasterxml.jackson.annotation.JsonIgnore;
import org.whispersystems.textsecuregcm.util.Util; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Optional;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collection; import java.util.LinkedList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
public class Account implements Serializable { public class Account implements Serializable {
private String number;
private boolean supportsSms;
private Map<Long, Device> devices = new HashMap<>();
private Account(String number, boolean supportsSms) { public static final int MEMCACHE_VERION = 2;
@JsonProperty
private String number;
@JsonProperty
private boolean supportsSms;
@JsonProperty
private List<Device> devices = new LinkedList<>();
@JsonIgnore
private Optional<Device> authenticatedDevice;
public Account() {}
public Account(String number, boolean supportsSms) {
this.number = number; this.number = number;
this.supportsSms = supportsSms; this.supportsSms = supportsSms;
} }
public Account(String number, boolean supportsSms, Device onlyDevice) { public Optional<Device> getAuthenticatedDevice() {
this(number, supportsSms); return authenticatedDevice;
addDevice(onlyDevice);
} }
public Account(String number, boolean supportsSms, List<Device> devices) { public void setAuthenticatedDevice(Device device) {
this(number, supportsSms); this.authenticatedDevice = Optional.of(device);
for (Device device : devices)
addDevice(device);
} }
public void setNumber(String number) { public void setNumber(String number) {
@ -64,30 +72,55 @@ public class Account implements Serializable {
this.supportsSms = supportsSms; this.supportsSms = supportsSms;
} }
public void addDevice(Device device) {
this.devices.add(device);
}
public void setDevices(List<Device> devices) {
this.devices = devices;
}
public List<Device> getDevices() {
return devices;
}
public Optional<Device> getMasterDevice() {
return getDevice(Device.MASTER_ID);
}
public Optional<Device> getDevice(long deviceId) {
for (Device device : devices) {
if (device.getId() == deviceId) {
return Optional.of(device);
}
}
return Optional.absent();
}
public boolean isActive() { public boolean isActive() {
Device masterDevice = devices.get((long) 1); return
return masterDevice != null && masterDevice.isActive(); getMasterDevice().isPresent() &&
getMasterDevice().get().isActive();
} }
public Collection<Device> getDevices() { public long getNextDeviceId() {
return devices.values(); long highestDevice = Device.MASTER_ID;
for (Device device : devices) {
if (device.getId() > highestDevice) {
highestDevice = device.getId();
}
} }
public Device getDevice(long destinationDeviceId) { return highestDevice + 1;
return devices.get(destinationDeviceId);
} }
public boolean hasAllDeviceIds(Set<Long> deviceIds) { public boolean isRateLimited() {
if (devices.size() != deviceIds.size())
return false;
for (long deviceId : devices.keySet()) {
if (!deviceIds.contains(deviceId))
return false;
}
return true; return true;
} }
public void addDevice(Device device) { public Optional<String> getRelay() {
devices.put(device.getDeviceId(), device); return Optional.absent();
} }
} }

View File

@ -16,6 +16,10 @@
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.skife.jdbi.v2.SQLStatement; import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext; import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel; import org.skife.jdbi.v2.TransactionIsolationLevel;
@ -28,10 +32,9 @@ import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate; import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction; import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper; import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.sqlobject.stringtemplate.UseStringTemplate3StatementLocator;
import org.skife.jdbi.v2.tweak.ResultSetMapper; import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.skife.jdbi.v2.unstable.BindIn;
import java.io.IOException;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType; import java.lang.annotation.ElementType;
import java.lang.annotation.Retention; import java.lang.annotation.Retention;
@ -39,91 +42,63 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Set;
@UseStringTemplate3StatementLocator
public abstract class Accounts { public abstract class Accounts {
public static final String ID = "id"; private static final String ID = "id";
public static final String NUMBER = "number"; private static final String NUMBER = "number";
public static final String DEVICE_ID = "device_id"; private static final String DATA = "data";
public static final String AUTH_TOKEN = "auth_token";
public static final String SALT = "salt";
public static final String SIGNALING_KEY = "signaling_key";
public static final String GCM_ID = "gcm_id";
public static final String APN_ID = "apn_id";
public static final String FETCHES_MESSAGES = "fetches_messages";
public static final String SUPPORTS_SMS = "supports_sms";
@SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DEVICE_ID + ", " + AUTH_TOKEN + ", " + private static final ObjectMapper mapper = new ObjectMapper();
SALT + ", " + SIGNALING_KEY + ", " + FETCHES_MESSAGES + ", " +
GCM_ID + ", " + APN_ID + ", " + SUPPORTS_SMS + ") " +
"VALUES (:number, :device_id, :auth_token, :salt, :signaling_key, :fetches_messages, :gcm_id, :apn_id, :supports_sms)")
@GetGeneratedKeys
abstract long insertStep(@AccountBinder Device device);
@SqlQuery("SELECT " + DEVICE_ID + " FROM accounts WHERE " + NUMBER + " = :number ORDER BY " + DEVICE_ID + " DESC LIMIT 1 FOR UPDATE") static {
abstract long getHighestDeviceId(@Bind("number") String number); mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public long insert(@AccountBinder Device device) {
device.setDeviceId(getHighestDeviceId(device.getNumber()) + 1);
return insertStep(device);
} }
@SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number RETURNING id") @SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
abstract void removeAccountsByNumber(@Bind("number") String number); @GetGeneratedKeys
abstract long insertStep(@AccountBinder Account account);
@SqlUpdate("UPDATE accounts SET " + AUTH_TOKEN + " = :auth_token, " + SALT + " = :salt, " + @SqlUpdate("DELETE FROM accounts WHERE " + NUMBER + " = :number")
SIGNALING_KEY + " = :signaling_key, " + GCM_ID + " = :gcm_id, " + APN_ID + " = :apn_id, " + abstract void removeAccount(@Bind("number") String number);
FETCHES_MESSAGES + " = :fetches_messages, " + SUPPORTS_SMS + " = :supports_sms " +
"WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id")
abstract void update(@AccountBinder Device device);
@Mapper(DeviceMapper.class) @SqlUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + NUMBER + " = :number")
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number AND " + DEVICE_ID + " = :device_id") abstract void update(@AccountBinder Account account);
abstract Device get(@Bind("number") String number, @Bind("device_id") long deviceId);
@Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
abstract Account get(@Bind("number") String number);
@SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts") @SqlQuery("SELECT COUNT(DISTINCT " + NUMBER + ") from accounts")
abstract long getNumberCount(); abstract long getCount();
@Mapper(DeviceMapper.class) @Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1 OFFSET :offset LIMIT :limit") @SqlQuery("SELECT * FROM accounts OFFSET :offset LIMIT :limit")
abstract List<Device> getAllMasterDevices(@Bind("offset") int offset, @Bind("limit") int length); abstract List<Account> getAll(@Bind("offset") int offset, @Bind("limit") int length);
@Mapper(DeviceMapper.class) @Mapper(AccountMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + DEVICE_ID + " = 1") @SqlQuery("SELECT * FROM accounts")
public abstract Iterator<Device> getAllMasterDevices(); public abstract Iterator<Account> getAll();
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " = :number")
public abstract List<Device> getAllByNumber(@Bind("number") String number);
@Mapper(DeviceMapper.class)
@SqlQuery("SELECT * FROM accounts WHERE " + NUMBER + " IN ( <numbers> )")
public abstract List<Device> getAllByNumbers(@BindIn("numbers") List<String> numbers);
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public long insertClearingNumber(Device device) { public long create(Account account) {
removeAccountsByNumber(device.getNumber()); removeAccount(account.getNumber());
device.setDeviceId(getHighestDeviceId(device.getNumber()) + 1); return insertStep(account);
return insertStep(device);
} }
public static class DeviceMapper implements ResultSetMapper<Device> { public static class AccountMapper implements ResultSetMapper<Account> {
@Override @Override
public Device map(int i, ResultSet resultSet, StatementContext statementContext) public Account map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException throws SQLException
{ {
return new Device(resultSet.getLong(ID), resultSet.getString(NUMBER), resultSet.getLong(DEVICE_ID), try {
resultSet.getString(AUTH_TOKEN), resultSet.getString(SALT), return mapper.readValue(resultSet.getString(DATA), Account.class);
resultSet.getString(SIGNALING_KEY), resultSet.getString(GCM_ID), } catch (IOException e) {
resultSet.getString(APN_ID), throw new SQLException(e);
resultSet.getInt(SUPPORTS_SMS) == 1, resultSet.getInt(FETCHES_MESSAGES) == 1); }
} }
} }
@ -134,23 +109,20 @@ public abstract class Accounts {
public static class AccountBinderFactory implements BinderFactory { public static class AccountBinderFactory implements BinderFactory {
@Override @Override
public Binder build(Annotation annotation) { public Binder build(Annotation annotation) {
return new Binder<AccountBinder, Device>() { return new Binder<AccountBinder, Account>() {
@Override @Override
public void bind(SQLStatement<?> sql, public void bind(SQLStatement<?> sql,
AccountBinder accountBinder, AccountBinder accountBinder,
Device device) Account account)
{ {
sql.bind(ID, device.getId()); try {
sql.bind(NUMBER, device.getNumber()); String serialized = mapper.writeValueAsString(account);
sql.bind(DEVICE_ID, device.getDeviceId());
sql.bind(AUTH_TOKEN, device.getAuthenticationCredentials() sql.bind(NUMBER, account.getNumber());
.getHashedAuthenticationToken()); sql.bind(DATA, serialized);
sql.bind(SALT, device.getAuthenticationCredentials().getSalt()); } catch (JsonProcessingException e) {
sql.bind(SIGNALING_KEY, device.getSignalingKey()); throw new IllegalArgumentException(e);
sql.bind(GCM_ID, device.getGcmRegistrationId()); }
sql.bind(APN_ID, device.getApnRegistrationId());
sql.bind(SUPPORTS_SMS, device.getSupportsSms() ? 1 : 0);
sql.bind(FETCHES_MESSAGES, device.getFetchesMessages() ? 1 : 0);
} }
}; };
} }

View File

@ -49,131 +49,67 @@ public class AccountsManager {
} }
public long getCount() { public long getCount() {
return accounts.getNumberCount(); return accounts.getCount();
} }
public List<Device> getAllMasterDevices(int offset, int length) { public List<Account> getAll(int offset, int length) {
return accounts.getAllMasterDevices(offset, length); return accounts.getAll(offset, length);
} }
public Iterator<Device> getAllMasterDevices() { public Iterator<Account> getAll() {
return accounts.getAllMasterDevices(); return accounts.getAll();
} }
/** Creates a new Account (WITH ONE DEVICE), clearing all existing devices on the given number */
public void create(Account account) { public void create(Account account) {
Device device = account.getDevices().iterator().next(); accounts.create(account);
long id = accounts.insertClearingNumber(device);
device.setId(id);
if (memcachedClient != null) { if (memcachedClient != null) {
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device); memcachedClient.set(getKey(account.getNumber()), 0, account);
} }
updateDirectory(device); updateDirectory(account);
} }
/** Creates a new Device for an existing Account */ public void update(Account account) {
public void provisionDevice(Device device) { if (memcachedClient != null) {
long id = accounts.insert(device); memcachedClient.set(getKey(account.getNumber()), 0, account);
device.setId(id); }
accounts.update(account);
updateDirectory(account);
}
public Optional<Account> get(String number) {
Account account = null;
if (memcachedClient != null) { if (memcachedClient != null) {
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device); account = (Account)memcachedClient.get(getKey(number));
} }
updateDirectory(device); if (account == null) {
} account = accounts.get(number);
public void update(Device device) { if (account != null && memcachedClient != null) {
if (memcachedClient != null) { memcachedClient.set(getKey(number), 0, account);
memcachedClient.set(getKey(device.getNumber(), device.getDeviceId()), 0, device);
}
accounts.update(device);
updateDirectory(device);
}
public Optional<Device> get(String number, long deviceId) {
Device device = null;
if (memcachedClient != null) {
device = (Device)memcachedClient.get(getKey(number, deviceId));
}
if (device == null) {
device = accounts.get(number, deviceId);
if (device != null && memcachedClient != null) {
memcachedClient.set(getKey(number, deviceId), 0, device);
} }
} }
if (device != null) return Optional.of(device); if (account != null) return Optional.of(account);
else return Optional.absent(); else return Optional.absent();
} }
public Optional<Account> getAccount(String number) { private void updateDirectory(Account account) {
List<Device> devices = accounts.getAllByNumber(number); if (account.isActive()) {
if (devices.isEmpty()) byte[] token = Util.getContactToken(account.getNumber());
return Optional.absent(); ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
return Optional.of(new Account(number, devices.get(0).getSupportsSms(), devices));
}
private List<Account> getAllAccounts(List<String> numbers) {
List<Device> devices = accounts.getAllByNumbers(numbers);
List<Account> accounts = new LinkedList<>();
for (Device device : devices) {
Account deviceAccount = null;
for (Account account : accounts) {
if (account.getNumber().equals(device.getNumber())) {
deviceAccount = account;
break;
}
}
if (deviceAccount == null) {
deviceAccount = new Account(device.getNumber(), false, device);
accounts.add(deviceAccount);
} else {
deviceAccount.addDevice(device);
}
if (device.getDeviceId() == 1)
deviceAccount.setSupportsSms(device.getSupportsSms());
}
return accounts;
}
public List<Account> getAccountsForDevices(Map<String, Set<Long>> destinations) throws MissingDevicesException {
Set<String> numbersMissingDevices = new HashSet<>(destinations.keySet());
List<Account> localAccounts = getAllAccounts(new LinkedList<>(destinations.keySet()));
for (Account account : localAccounts){
if (account.hasAllDeviceIds(destinations.get(account.getNumber())))
numbersMissingDevices.remove(account.getNumber());
}
if (!numbersMissingDevices.isEmpty())
throw new MissingDevicesException(numbersMissingDevices);
return localAccounts;
}
private void updateDirectory(Device device) {
if (device.getDeviceId() != 1)
return;
if (device.isActive()) {
byte[] token = Util.getContactToken(device.getNumber());
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms());
directory.add(clientContact); directory.add(clientContact);
} else { } else {
directory.remove(device.getNumber()); directory.remove(account.getNumber());
} }
} }
private String getKey(String number, long accountId) { private String getKey(String number) {
return Device.class.getSimpleName() + Device.MEMCACHE_VERION + number + accountId; return Account.class.getSimpleName() + Account.MEMCACHE_VERION + number;
} }
} }

View File

@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -24,97 +25,58 @@ import java.io.Serializable;
public class Device implements Serializable { public class Device implements Serializable {
public static final int MEMCACHE_VERION = 1; public static final long MASTER_ID = 1;
@JsonProperty
private long id; private long id;
private String number;
private long deviceId; @JsonProperty
private String hashedAuthenticationToken; private String authToken;
@JsonProperty
private String salt; private String salt;
@JsonProperty
private String signalingKey; private String signalingKey;
/**
* In order for us to tell a client that an account is "inactive" (ie go use SMS for transport), we check that all @JsonProperty
* non-fetching Accounts don't have push registrations. In this way, we can ensure that we have some form of transport private String gcmId;
* available for all Accounts on all "active" numbers.
*/ @JsonProperty
private String gcmRegistrationId; private String apnId;
private String apnRegistrationId;
private boolean supportsSms; @JsonProperty
private boolean fetchesMessages; private boolean fetchesMessages;
public Device() {} public Device() {}
public Device(long id, String number, long deviceId, String hashedAuthenticationToken, String salt, public Device(long id, String authToken, String salt,
String signalingKey, String gcmRegistrationId, String apnRegistrationId, String signalingKey, String gcmId, String apnId,
boolean supportsSms, boolean fetchesMessages) boolean fetchesMessages)
{ {
this.id = id; this.id = id;
this.number = number; this.authToken = authToken;
this.deviceId = deviceId;
this.hashedAuthenticationToken = hashedAuthenticationToken;
this.salt = salt; this.salt = salt;
this.signalingKey = signalingKey; this.signalingKey = signalingKey;
this.gcmRegistrationId = gcmRegistrationId; this.gcmId = gcmId;
this.apnRegistrationId = apnRegistrationId; this.apnId = apnId;
this.supportsSms = supportsSms;
this.fetchesMessages = fetchesMessages; this.fetchesMessages = fetchesMessages;
} }
public String getApnRegistrationId() { public String getApnId() {
return apnRegistrationId; return apnId;
} }
public void setApnRegistrationId(String apnRegistrationId) { public void setApnId(String apnId) {
this.apnRegistrationId = apnRegistrationId; this.apnId = apnId;
} }
public String getGcmRegistrationId() { public String getGcmId() {
return gcmRegistrationId; return gcmId;
} }
public void setGcmRegistrationId(String gcmRegistrationId) { public void setGcmId(String gcmId) {
this.gcmRegistrationId = gcmRegistrationId; this.gcmId = gcmId;
}
public void setNumber(String number) {
this.number = number;
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public void setDeviceId(long deviceId) {
this.deviceId = deviceId;
}
public void setAuthenticationCredentials(AuthenticationCredentials credentials) {
this.hashedAuthenticationToken = credentials.getHashedAuthenticationToken();
this.salt = credentials.getSalt();
}
public AuthenticationCredentials getAuthenticationCredentials() {
return new AuthenticationCredentials(hashedAuthenticationToken, salt);
}
public String getSignalingKey() {
return signalingKey;
}
public void setSignalingKey(String signalingKey) {
this.signalingKey = signalingKey;
}
public boolean getSupportsSms() {
return supportsSms;
}
public void setSupportsSms(boolean supportsSms) {
this.supportsSms = supportsSms;
} }
public long getId() { public long getId() {
@ -125,8 +87,25 @@ public class Device implements Serializable {
this.id = id; this.id = id;
} }
public void setAuthenticationCredentials(AuthenticationCredentials credentials) {
this.authToken = credentials.getHashedAuthenticationToken();
this.salt = credentials.getSalt();
}
public AuthenticationCredentials getAuthenticationCredentials() {
return new AuthenticationCredentials(authToken, salt);
}
public String getSignalingKey() {
return signalingKey;
}
public void setSignalingKey(String signalingKey) {
this.signalingKey = signalingKey;
}
public boolean isActive() { public boolean isActive() {
return fetchesMessages || !Util.isEmpty(getApnRegistrationId()) || !Util.isEmpty(getGcmRegistrationId()); return fetchesMessages || !Util.isEmpty(getApnId()) || !Util.isEmpty(getGcmId());
} }
public boolean getFetchesMessages() { public boolean getFetchesMessages() {
@ -137,7 +116,7 @@ public class Device implements Serializable {
this.fetchesMessages = fetchesMessages; this.fetchesMessages = fetchesMessages;
} }
public String getBackwardsCompatibleNumberEncoding() { public boolean isMaster() {
return deviceId == 1 ? number : (number + "." + deviceId); return getId() == MASTER_ID;
} }
} }

View File

@ -16,6 +16,7 @@
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.common.base.Optional;
import org.skife.jdbi.v2.SQLStatement; import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext; import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.TransactionIsolationLevel; import org.skife.jdbi.v2.TransactionIsolationLevel;
@ -62,8 +63,11 @@ public abstract class Keys {
@Mapper(PreKeyMapper.class) @Mapper(PreKeyMapper.class)
abstract PreKey retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId); abstract PreKey retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY key_id ASC FOR UPDATE")
abstract List<PreKey> retrieveFirst(@Bind("number") String number);
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, PreKey lastResortKey, List<PreKey> keys) { public void store(String number, long deviceId, List<PreKey> keys, PreKey lastResortKey) {
for (PreKey key : keys) { for (PreKey key : keys) {
key.setNumber(number); key.setNumber(number);
key.setDeviceId(deviceId); key.setDeviceId(deviceId);
@ -79,20 +83,31 @@ public abstract class Keys {
} }
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public UnstructuredPreKeyList get(String number, Account account) { public Optional<PreKey> get(String number, long deviceId) {
List<PreKey> preKeys = new LinkedList<>(); PreKey preKey = retrieveFirst(number, deviceId);
for (Device device : account.getDevices()) {
PreKey preKey = retrieveFirst(number, device.getDeviceId());
if (preKey != null)
preKeys.add(preKey);
}
for (PreKey preKey : preKeys) { if (preKey != null && !preKey.isLastResort()) {
if (!preKey.isLastResort())
removeKey(preKey.getId()); removeKey(preKey.getId());
} }
return new UnstructuredPreKeyList(preKeys); if (preKey != null) return Optional.of(preKey);
else return Optional.absent();
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<UnstructuredPreKeyList> get(String number) {
List<PreKey> preKeys = retrieveFirst(number);
if (preKeys != null) {
for (PreKey preKey : preKeys) {
if (!preKey.isLastResort()) {
removeKey(preKey.getId());
}
}
}
if (preKeys != null) return Optional.of(new UnstructuredPreKeyList(preKeys));
else return Optional.absent();
} }
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class) @BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)

View File

@ -20,7 +20,7 @@ import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.SqlQuery; import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate; import org.skife.jdbi.v2.sqlobject.SqlUpdate;
public interface PendingDeviceRegistrations { public interface PendingDevices {
@SqlUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code WHERE number = :number RETURNING *) " + @SqlUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code WHERE number = :number RETURNING *) " +
"INSERT INTO pending_devices (number, verification_code) SELECT :number, :verification_code WHERE NOT EXISTS (SELECT * FROM upsert)") "INSERT INTO pending_devices (number, verification_code) SELECT :number, :verification_code WHERE NOT EXISTS (SELECT * FROM upsert)")

View File

@ -23,10 +23,10 @@ public class PendingDevicesManager {
private static final String MEMCACHE_PREFIX = "pending_devices"; private static final String MEMCACHE_PREFIX = "pending_devices";
private final PendingDeviceRegistrations pendingDevices; private final PendingDevices pendingDevices;
private final MemcachedClient memcachedClient; private final MemcachedClient memcachedClient;
public PendingDevicesManager(PendingDeviceRegistrations pendingDevices, public PendingDevicesManager(PendingDevices pendingDevices,
MemcachedClient memcachedClient) MemcachedClient memcachedClient)
{ {
this.pendingDevices = pendingDevices; this.pendingDevices = pendingDevices;
@ -42,8 +42,10 @@ public class PendingDevicesManager {
} }
public void remove(String number) { public void remove(String number) {
if (memcachedClient != null) if (memcachedClient != null) {
memcachedClient.delete(MEMCACHE_PREFIX + number); memcachedClient.delete(MEMCACHE_PREFIX + number);
}
pendingDevices.remove(number); pendingDevices.remove(number);
} }

View File

@ -16,6 +16,7 @@
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import java.io.IOException; import java.io.IOException;
@ -27,7 +28,9 @@ public class StoredMessageManager {
this.storedMessages = storedMessages; this.storedMessages = storedMessages;
} }
public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage) throws IOException { public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
throws CryptoEncodingException
{
storedMessages.insert(device.getId(), outgoingMessage.serialize()); storedMessages.insert(device.getId(), outgoingMessage.serialize());
} }

View File

@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.federation.FederatedClient; import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle; import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle;
@ -53,22 +53,23 @@ public class DirectoryUpdater {
BatchOperationHandle batchOperation = directory.startBatchOperation(); BatchOperationHandle batchOperation = directory.startBatchOperation();
try { try {
Iterator<Device> accounts = accountsManager.getAllMasterDevices(); Iterator<Account> accounts = accountsManager.getAll();
if (accounts == null) if (accounts == null)
return; return;
while (accounts.hasNext()) { while (accounts.hasNext()) {
Device device = accounts.next(); Account account = accounts.next();
if (device.isActive()) {
byte[] token = Util.getContactToken(device.getNumber()); if (account.isActive()) {
ClientContact clientContact = new ClientContact(token, null, device.getSupportsSms()); byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
directory.add(batchOperation, clientContact); directory.add(batchOperation, clientContact);
logger.debug("Adding local token: " + Base64.encodeBytesWithoutPadding(token)); logger.debug("Adding local token: " + Base64.encodeBytesWithoutPadding(token));
} else { } else {
directory.remove(batchOperation, device.getNumber()); directory.remove(batchOperation, account.getNumber());
} }
} }
} finally { } finally {

View File

@ -77,19 +77,24 @@
</changeSet> </changeSet>
<changeSet id="2" author="matt"> <changeSet id="2" author="matt">
<addColumn tableName="accounts">
<column name="device_id" type="bigint">
<constraints nullable="false" />
</column>
<column name="fetches_messages" type="smallint" defaultValue="0"/> <addColumn tableName="accounts">
<column name="data" type="json" />
</addColumn> </addColumn>
<dropUniqueConstraint tableName="accounts" constraintName="accounts_number_key" /> <sql>UPDATE accounts SET data = CAST(('{"number" : "' || number || '", "supportsSms" : ' || supports_sms || ', "devices" : [{"id" : 1, "authToken" : "' || auth_token || '", "salt" : "' || salt || '"' || CASE WHEN signaling_key IS NOT NULL THEN ', "signalingKey" : "' || signaling_key || '"' ELSE '' END || CASE WHEN gcm_id IS NOT NULL THEN ', "gcmId" : "' || gcm_id || '"' ELSE '' END || CASE WHEN apn_id IS NOT NULL THEN ', "apnId" : "' || apn_id || '"' ELSE '' END || '}]}') AS json);</sql>
<addUniqueConstraint constraintName="account_number_device_unique" tableName="accounts" columnNames="number, device_id" />
<addNotNullConstraint tableName="accounts" columnName="data"/>
<dropColumn tableName="accounts" columnName="auth_token"/>
<dropColumn tableName="accounts" columnName="salt"/>
<dropColumn tableName="accounts" columnName="signaling_key"/>
<dropColumn tableName="accounts" columnName="gcm_id"/>
<dropColumn tableName="accounts" columnName="apn_id"/>
<dropColumn tableName="accounts" columnName="supports_sms"/>
<addColumn tableName="keys"> <addColumn tableName="keys">
<column name="device_id" type="bigint" > <column name="device_id" type="bigint" defaultValue="1">
<constraints nullable="false" /> <constraints nullable="false" />
</column> </column>
</addColumn> </addColumn>

View File

@ -1,61 +1,26 @@
package org.whispersystems.textsecuregcm.tests.controllers; package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
import com.yammer.dropwizard.testing.ResourceTest; import com.yammer.dropwizard.testing.ResourceTest;
import org.hibernate.validator.constraints.NotEmpty;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import javax.ws.rs.Path;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import static org.fest.assertions.api.Assertions.assertThat; import static org.fest.assertions.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
public class AccountControllerTest extends ResourceTest { public class AccountControllerTest extends ResourceTest {
/** The AccountAttributes used in protocol v1 (no fetchesMessages) */
static class V1AccountAttributes {
@JsonProperty
@NotEmpty
private String signalingKey;
@JsonProperty
private boolean supportsSms;
public V1AccountAttributes(String signalingKey, boolean supportsSms) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
}
}
@Path("/v1/accounts")
static class DumbVerificationAccountController extends AccountController {
public DumbVerificationAccountController(PendingAccountsManager pendingAccounts, AccountsManager accounts, RateLimiters rateLimiters, SmsSender smsSenderFactory) {
super(pendingAccounts, accounts, rateLimiters, smsSenderFactory);
}
@Override
protected VerificationCode generateVerificationCode() {
return new VerificationCode(5678901);
}
}
private static final String SENDER = "+14152222222"; private static final String SENDER = "+14152222222";
@ -75,15 +40,10 @@ public class AccountControllerTest extends ResourceTest {
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of("1234")); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of("1234"));
Mockito.doAnswer(new Answer() { addResource(new AccountController(pendingAccountsManager,
@Override accountsManager,
public Object answer(InvocationOnMock invocation) throws Throwable { rateLimiters,
((Device)invocation.getArguments()[0]).setDeviceId(2); smsSender));
return null;
}
}).when(accountsManager).provisionDevice(any(Device.class));
addResource(new DumbVerificationAccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender));
} }
@Test @Test
@ -102,17 +62,13 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response = ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1234")) client().resource(String.format("/v1/accounts/code/%s", "1234"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new V1AccountAttributes("keykeykeykey", false)) .entity(new AccountAttributes("keykeykeykey", false, false))
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class); .put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager).create(isA(Account.class)); verify(accountsManager).create(isA(Account.class));
ArgumentCaptor<String> number = ArgumentCaptor.forClass(String.class);
verify(pendingAccountsManager).remove(number.capture());
assertThat(number.getValue()).isEqualTo(SENDER);
} }
@Test @Test
@ -120,7 +76,7 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response = ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1111")) client().resource(String.format("/v1/accounts/code/%s", "1111"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar")) .header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new V1AccountAttributes("keykeykeykey", false)) .entity(new AccountAttributes("keykeykeykey", false, false))
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class); .put(ClientResponse.class);

View File

@ -19,15 +19,12 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.yammer.dropwizard.testing.ResourceTest; import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager; import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -37,10 +34,7 @@ import javax.ws.rs.Path;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import static org.fest.assertions.api.Assertions.assertThat; import static org.fest.assertions.api.Assertions.assertThat;
import static org.mockito.Matchers.any; import static org.mockito.Mockito.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class DeviceControllerTest extends ResourceTest { public class DeviceControllerTest extends ResourceTest {
@Path("/v1/devices") @Path("/v1/devices")
@ -55,12 +49,11 @@ public class DeviceControllerTest extends ResourceTest {
} }
} }
private static final String SENDER = "+14152222222";
private PendingDevicesManager pendingDevicesManager = mock(PendingDevicesManager.class); private PendingDevicesManager pendingDevicesManager = mock(PendingDevicesManager.class);
private AccountsManager accountsManager = mock(AccountsManager.class ); private AccountsManager accountsManager = mock(AccountsManager.class );
private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class ); private RateLimiter rateLimiter = mock(RateLimiter.class );
private Account account = mock(Account.class );
@Override @Override
protected void setUpResources() throws Exception { protected void setUpResources() throws Exception {
@ -70,15 +63,10 @@ public class DeviceControllerTest extends ResourceTest {
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of("5678901")); when(account.getNextDeviceId()).thenReturn(42L);
Mockito.doAnswer(new Answer() { when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of("5678901"));
@Override when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
public Object answer(InvocationOnMock invocation) throws Throwable {
((Device) invocation.getArguments()[0]).setDeviceId(2);
return null;
}
}).when(accountsManager).provisionDevice(any(Device.class));
addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, rateLimiters)); addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, rateLimiters));
} }
@ -91,19 +79,14 @@ public class DeviceControllerTest extends ResourceTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
Long deviceId = client().resource(String.format("/v1/devices/5678901")) DeviceResponse response = client().resource("/v1/devices/5678901")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.entity(new AccountAttributes("keykeykeykey", false, true)) .entity(new AccountAttributes("keykeykeykey", false, true))
.type(MediaType.APPLICATION_JSON_TYPE) .type(MediaType.APPLICATION_JSON_TYPE)
.put(Long.class); .put(DeviceResponse.class);
assertThat(deviceId).isNotEqualTo(AuthHelper.DEFAULT_DEVICE_ID);
ArgumentCaptor<Device> newAccount = ArgumentCaptor.forClass(Device.class); assertThat(response.getDeviceId()).isEqualTo(42L);
verify(accountsManager).provisionDevice(newAccount.capture());
assertThat(deviceId).isEqualTo(newAccount.getValue().getDeviceId());
ArgumentCaptor<String> number = ArgumentCaptor.forClass(String.class); verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
verify(pendingDevicesManager).remove(number.capture());
assertThat(number.getValue()).isEqualTo(AuthHelper.VALID_NUMBER);
} }
} }

View File

@ -2,7 +2,6 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
import com.sun.jersey.api.client.GenericType;
import com.yammer.dropwizard.testing.ResourceTest; import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.KeysController;
@ -10,13 +9,10 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.util.Arrays;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -28,42 +24,32 @@ public class KeyControllerTest extends ResourceTest {
private final String EXISTS_NUMBER = "+14152222222"; private final String EXISTS_NUMBER = "+14152222222";
private final String NOT_EXISTS_NUMBER = "+14152222220"; private final String NOT_EXISTS_NUMBER = "+14152222220";
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, AuthHelper.DEFAULT_DEVICE_ID, 1234, "test1", "test2", false); private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false); private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false);
private final Keys keys = mock(Keys.class); private final Keys keys = mock(Keys.class);
Device[] fakeDevice;
Account existsAccount;
@Override @Override
protected void setUpResources() { protected void setUpResources() {
addProvider(AuthHelper.getAuthenticator()); addProvider(AuthHelper.getAuthenticator());
RateLimiters rateLimiters = mock(RateLimiters.class); RateLimiters rateLimiters = mock(RateLimiters.class);
RateLimiter rateLimiter = mock(RateLimiter.class ); RateLimiter rateLimiter = mock(RateLimiter.class );
AccountsManager accounts = mock(AccountsManager.class);
fakeDevice = new Device[2];
fakeDevice[0] = mock(Device.class);
fakeDevice[1] = mock(Device.class);
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(keys.get(eq(EXISTS_NUMBER), isA(Account.class))).thenReturn(new UnstructuredPreKeyList(Arrays.asList(SAMPLE_KEY, SAMPLE_KEY2))); when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY));
when(keys.get(eq(NOT_EXISTS_NUMBER), isA(Account.class))).thenReturn(null); when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<PreKey>absent());
when(fakeDevice[0].getDeviceId()).thenReturn(AuthHelper.DEFAULT_DEVICE_ID); List<PreKey> allKeys = new LinkedList<>();
when(fakeDevice[1].getDeviceId()).thenReturn((long) 2); allKeys.add(SAMPLE_KEY);
allKeys.add(SAMPLE_KEY2);
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(new UnstructuredPreKeyList(allKeys)));
when(accounts.getAccount(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount)); addResource(new KeysController(rateLimiters, keys, null));
when(accounts.getAccount(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>absent());
addResource(new KeysController(rateLimiters, keys, accounts, null));
} }
@Test @Test
public void validRequestsTest() throws Exception { public void validLegacyRequestTest() throws Exception {
PreKey result = client().resource(String.format("/v1/keys/%s", EXISTS_NUMBER)) PreKey result = client().resource(String.format("/v1/keys/%s", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKey.class); .get(PreKey.class);
@ -75,15 +61,20 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getId() == 0); assertThat(result.getId() == 0);
assertThat(result.getNumber() == null); assertThat(result.getNumber() == null);
verify(keys).get(eq(EXISTS_NUMBER), eq(existsAccount)); verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
}
List<PreKey> results = client().resource(String.format("/v1/keys/%s?multikeys", EXISTS_NUMBER)) @Test
public void validMultiRequestTest() throws Exception {
UnstructuredPreKeyList results = client().resource(String.format("/v1/keys/%s/*", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(new GenericType<List<PreKey>>(){}); .get(UnstructuredPreKeyList.class);
assertThat(results.getKeys().size()).isEqualTo(2);
PreKey result = results.getKeys().get(0);
assertThat(results.size()).isEqualTo(2);
result = results.get(0);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY.getIdentityKey());
@ -91,18 +82,19 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getId() == 0); assertThat(result.getId() == 0);
assertThat(result.getNumber() == null); assertThat(result.getNumber() == null);
result = results.get(1); result = results.getKeys().get(1);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId()); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey()); assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY2.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY2.getIdentityKey());
assertThat(result.getId() == 1); assertThat(result.getId() == 0);
assertThat(result.getNumber() == null); assertThat(result.getNumber() == null);
verify(keys, times(2)).get(eq(EXISTS_NUMBER), eq(existsAccount)); verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
} }
@Test @Test
public void invalidRequestTest() throws Exception { public void invalidRequestTest() throws Exception {
ClientResponse response = client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER)) ClientResponse response = client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
@ -110,7 +102,6 @@ public class KeyControllerTest extends ResourceTest {
.get(ClientResponse.class); .get(ClientResponse.class);
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404); assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(404);
verifyNoMoreInteractions(keys);
} }
@Test @Test

View File

@ -1,56 +1,48 @@
package org.whispersystems.textsecuregcm.tests.util; package org.whispersystems.textsecuregcm.tests.util;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.auth.DeviceAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator; import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider; import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
import org.whispersystems.textsecuregcm.configuration.FederationConfiguration; import org.whispersystems.textsecuregcm.configuration.FederationConfiguration;
import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import java.util.Arrays; import java.util.Arrays;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class AuthHelper { public class AuthHelper {
public static final long DEFAULT_DEVICE_ID = 1;
public static final String VALID_NUMBER = "+14150000000"; public static final String VALID_NUMBER = "+14150000000";
public static final String VALID_PASSWORD = "foo"; public static final String VALID_PASSWORD = "foo";
public static final String INVVALID_NUMBER = "+14151111111"; public static final String INVVALID_NUMBER = "+14151111111";
public static final String INVALID_PASSWORD = "bar"; public static final String INVALID_PASSWORD = "bar";
public static final String VALID_FEDERATION_PEER = "valid_peer"; public static MultiBasicAuthProvider<FederatedPeer, Account> getAuthenticator() {
public static final String FEDERATION_PEER_TOKEN = "magic"; AccountsManager accounts = mock(AccountsManager.class );
Account account = mock(Account.class );
public static MultiBasicAuthProvider<FederatedPeer, Device> getAuthenticator() { Device device = mock(Device.class );
FederationConfiguration federationConfig = mock(FederationConfiguration.class);
when(federationConfig.getPeers()).thenReturn(Arrays.asList(new FederatedPeer(VALID_FEDERATION_PEER, "", FEDERATION_PEER_TOKEN, "")));
AccountsManager accounts = mock(AccountsManager.class);
Device device = mock(Device.class);
AuthenticationCredentials credentials = mock(AuthenticationCredentials.class); AuthenticationCredentials credentials = mock(AuthenticationCredentials.class);
when(credentials.verify("foo")).thenReturn(true); when(credentials.verify("foo")).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials); when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(accounts.get(VALID_NUMBER, DEFAULT_DEVICE_ID)).thenReturn(Optional.of(device)); when(account.getDevice(anyLong())).thenReturn(Optional.of(device));
when(accounts.get(VALID_NUMBER)).thenReturn(Optional.of(account));
return new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(federationConfig), return new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(new FederationConfiguration()),
FederatedPeer.class, FederatedPeer.class,
new DeviceAuthenticator(accounts), new AccountAuthenticator(accounts),
Device.class, "WhisperServer"); Account.class, "WhisperServer");
} }
public static String getAuthHeader(String number, String password) { public static String getAuthHeader(String number, String password) {
return "Basic " + Base64.encodeBytes((number + ":" + password).getBytes()); return "Basic " + Base64.encodeBytes((number + ":" + password).getBytes());
} }
public static String getV2AuthHeader(String number, long deviceId, String password) {
return "Basic " + Base64.encodeBytes((number + "." + deviceId + ":" + password).getBytes());
}
} }

View File

@ -1,11 +0,0 @@
package org.whispersystems.textsecuregcm.tests;
/**
* Created with IntelliJ IDEA.
* User: moxie
* Date: 10/28/13
* Time: 12:53 PM
* To change this template use File | Settings | File Templates.
*/
public class BaseTest {
}

View File

@ -1,127 +0,0 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import com.sun.jersey.api.client.GenericType;
import com.yammer.dropwizard.testing.ResourceTest;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.FederationController;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.MissingDevicesException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.RelayMessage;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import javax.ws.rs.core.MediaType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class FederatedControllerTest extends ResourceTest {
private final String EXISTS_NUMBER = "+14152222222";
private final String EXISTS_NUMBER_2 = "+14154444444";
private final String NOT_EXISTS_NUMBER = "+14152222220";
private final Keys keys = mock(Keys.class);
private final PushSender pushSender = mock(PushSender.class);
Device[] fakeDevice;
Account existsAccount;
@Override
protected void setUpResources() throws MissingDevicesException {
addProvider(AuthHelper.getAuthenticator());
RateLimiters rateLimiters = mock(RateLimiters.class);
RateLimiter rateLimiter = mock(RateLimiter.class );
AccountsManager accounts = mock(AccountsManager.class);
fakeDevice = new Device[2];
fakeDevice[0] = new Device(42, EXISTS_NUMBER, 1, "", "", "", null, null, true, false);
fakeDevice[1] = new Device(43, EXISTS_NUMBER, 2, "", "", "", null, null, false, true);
existsAccount = new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1]));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
Map<String, Set<Long>> validOneElementSet = new HashMap<>();
validOneElementSet.put(EXISTS_NUMBER_2, new HashSet<>(Arrays.asList((long) 1)));
List<Account> validOneAccount = Arrays.asList(new Account(EXISTS_NUMBER_2, true,
Arrays.asList(new Device(44, EXISTS_NUMBER_2, 1, "", "", "", null, null, true, false))));
Map<String, Set<Long>> validTwoElementsSet = new HashMap<>();
validTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1, (long) 2)));
List<Account> validTwoAccount = Arrays.asList(new Account(EXISTS_NUMBER, true, Arrays.asList(fakeDevice[0], fakeDevice[1])));
Map<String, Set<Long>> invalidTwoElementsSet = new HashMap<>();
invalidTwoElementsSet.put(EXISTS_NUMBER, new HashSet<>(Arrays.asList((long) 1)));
when(accounts.getAccountsForDevices(eq(validOneElementSet))).thenReturn(validOneAccount);
when(accounts.getAccountsForDevices(eq(validTwoElementsSet))).thenReturn(validTwoAccount);
when(accounts.getAccountsForDevices(eq(invalidTwoElementsSet))).thenThrow(new MissingDevicesException(new HashSet<>(Arrays.asList(EXISTS_NUMBER))));
addResource(new FederationController(keys, accounts, pushSender, mock(UrlSigner.class)));
}
@Test
public void validRequestsTest() throws Exception {
MessageResponse result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER_2, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER_2));
assertThat(result.getFailure()).isEmpty();
assertThat(result.getNumbersMissingDevices()).isEmpty();
result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray()),
new RelayMessage(EXISTS_NUMBER, 2, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEqualTo(Arrays.asList(EXISTS_NUMBER, EXISTS_NUMBER + "." + 2));
assertThat(result.getFailure()).isEmpty();
assertThat(result.getNumbersMissingDevices()).isEmpty();
}
@Test
public void invalidRequestTest() throws Exception {
MessageResponse result = client().resource("/v1/federation/message")
.entity(Arrays.asList(new RelayMessage(EXISTS_NUMBER, 1, MessageProtos.OutgoingMessageSignal.newBuilder().build().toByteArray())))
.type(MediaType.APPLICATION_JSON)
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_FEDERATION_PEER, AuthHelper.FEDERATION_PEER_TOKEN))
.put(MessageResponse.class);
assertThat(result.getSuccess()).isEmpty();
assertThat(result.getFailure()).isEqualTo(Arrays.asList(EXISTS_NUMBER));
assertThat(result.getNumbersMissingDevices()).isEqualTo(new HashSet<>(Arrays.asList(EXISTS_NUMBER)));
}
}