Add support for "registrationId" session enforcement.

This commit is contained in:
Moxie Marlinspike 2014-02-20 09:32:42 -08:00
parent 35e212a30f
commit f4ecb5d7be
18 changed files with 204 additions and 32 deletions

View File

@ -9,7 +9,7 @@
<groupId>org.whispersystems.textsecure</groupId>
<artifactId>TextSecureServer</artifactId>
<version>0.3</version>
<version>0.4</version>
<dependencies>
<dependency>

View File

@ -131,7 +131,7 @@ public class WhisperServerService extends Service<WhisperServerConfiguration> {
accountsManager);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, federatedClientManager);
KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
environment.addProvider(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),

View File

@ -140,6 +140,7 @@ public class AccountController {
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
Account account = new Account();
account.setNumber(number);

View File

@ -28,6 +28,7 @@ import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
@ -42,6 +43,8 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List;
@Path("/v1/keys")
public class KeysController {
@ -50,13 +53,15 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final Keys keys;
private final AccountsManager accounts;
private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys,
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.keys = keys;
this.accounts = accounts;
this.federatedClientManager = federatedClientManager;
}
@ -108,18 +113,50 @@ public class KeysController {
return results.getKeys().get(0);
}
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceId) {
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceIdSelector) {
Optional<Account> destination = accounts.get(number);
if (!destination.isPresent() || !destination.get().isActive()) {
return Optional.absent();
}
try {
if (deviceId.equals("*")) {
return keys.get(number);
if (deviceIdSelector.equals("*")) {
Optional<UnstructuredPreKeyList> preKeys = keys.get(number);
return getActiveKeys(destination.get(), preKeys);
}
Optional<PreKey> targetKey = keys.get(number, Long.parseLong(deviceId));
long deviceId = Long.parseLong(deviceIdSelector);
Optional<Device> targetDevice = destination.get().getDevice(deviceId);
if (targetKey.isPresent()) return Optional.of(new UnstructuredPreKeyList(targetKey.get()));
else return Optional.absent();
if (!targetDevice.isPresent() || !targetDevice.get().isActive()) {
return Optional.absent();
}
Optional<UnstructuredPreKeyList> preKeys = keys.get(number, deviceId);
return getActiveKeys(destination.get(), preKeys);
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
private Optional<UnstructuredPreKeyList> getActiveKeys(Account destination,
Optional<UnstructuredPreKeyList> preKeys)
{
if (!preKeys.isPresent()) return Optional.absent();
List<PreKey> filteredKeys = new LinkedList<>();
for (PreKey preKey : preKeys.get().getKeys()) {
Optional<Device> device = destination.getDevice(preKey.getDeviceId());
if (device.isPresent() && device.get().isActive()) {
preKey.setRegistrationId(device.get().getRegistrationId());
filteredKeys.add(preKey);
}
}
if (filteredKeys.isEmpty()) return Optional.absent();
else return Optional.of(new UnstructuredPreKeyList(filteredKeys));
}
}

View File

@ -27,6 +27,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
@ -98,6 +99,11 @@ public class MessageController {
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
}
}
@ -124,11 +130,12 @@ public class MessageController {
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages)
throws NoSuchUserException, MismatchedDevicesException, IOException
throws NoSuchUserException, MismatchedDevicesException, IOException, StaleDevicesException
{
Account destination = getDestinationAccount(destinationName);
validateCompleteDeviceList(destination, messages.getMessages());
validateRegistrationIds(destination, messages.getMessages());
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
@ -197,6 +204,27 @@ public class MessageController {
return account.get();
}
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException
{
List<Long> staleDevices = new LinkedList<>();
for (IncomingMessage message : messages) {
Optional<Device> device = account.getDevice(message.getDestinationDeviceId());
if (device.isPresent() &&
message.getDestinationRegistrationId() > 0 &&
message.getDestinationRegistrationId() != device.get().getRegistrationId())
{
staleDevices.add(device.get().getId());
}
}
if (!staleDevices.isEmpty()) {
throw new StaleDevicesException(staleDevices);
}
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages)
throws MismatchedDevicesException
{
@ -211,10 +239,12 @@ public class MessageController {
}
for (Device device : account.getDevices()) {
accountDeviceIds.add(device.getId());
if (device.isActive()) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {
missingDeviceIds.add(device.getId());
}
}
}

View File

@ -0,0 +1,16 @@
package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Throwable {
private final List<Long> staleDevices;
public StaleDevicesException(List<Long> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Long> getStaleDevices() {
return staleDevices;
}
}

View File

@ -31,12 +31,16 @@ public class AccountAttributes {
@JsonProperty
private boolean fetchesMessages;
@JsonProperty
private int registrationId;
public AccountAttributes() {}
public AccountAttributes(String signalingKey, boolean supportsSms, boolean fetchesMessages) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
public AccountAttributes(String signalingKey, boolean supportsSms, boolean fetchesMessages, int registrationId) {
this.signalingKey = signalingKey;
this.supportsSms = supportsSms;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
}
public String getSignalingKey() {
@ -51,4 +55,7 @@ public class AccountAttributes {
return fetchesMessages;
}
public int getRegistrationId() {
return registrationId;
}
}

View File

@ -30,6 +30,9 @@ public class IncomingMessage {
@JsonProperty
private long destinationDeviceId = 1;
@JsonProperty
private int destinationRegistrationId;
@JsonProperty
@NotEmpty
private String body;
@ -40,6 +43,7 @@ public class IncomingMessage {
@JsonProperty
private long timestamp;
public String getDestination() {
return destination;
}
@ -59,4 +63,8 @@ public class IncomingMessage {
public long getDestinationDeviceId() {
return destinationDeviceId;
}
public int getDestinationRegistrationId() {
return destinationRegistrationId;
}
}

View File

@ -52,6 +52,9 @@ public class PreKey {
@JsonProperty
private boolean lastResort;
@JsonProperty
private int registrationId;
public PreKey() {}
public PreKey(long id, String number, long deviceId, long keyId,
@ -125,4 +128,12 @@ public class PreKey {
public long getDeviceId() {
return deviceId;
}
public int getRegistrationId() {
return registrationId;
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
}

View File

@ -0,0 +1,18 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
public class StaleDevices {
@JsonProperty
private List<Long> staleDevices;
public StaleDevices() {}
public StaleDevices(List<Long> staleDevices) {
this.staleDevices = staleDevices;
}
}

View File

@ -48,11 +48,14 @@ public class Device implements Serializable {
@JsonProperty
private boolean fetchesMessages;
@JsonProperty
private int registrationId;
public Device() {}
public Device(long id, String authToken, String salt,
String signalingKey, String gcmId, String apnId,
boolean fetchesMessages)
boolean fetchesMessages, int registrationId)
{
this.id = id;
this.authToken = authToken;
@ -61,6 +64,7 @@ public class Device implements Serializable {
this.gcmId = gcmId;
this.apnId = apnId;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
}
public String getApnId() {
@ -119,4 +123,12 @@ public class Device implements Serializable {
public boolean isMaster() {
return getId() == MASTER_ID;
}
public int getRegistrationId() {
return registrationId;
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
}

View File

@ -84,14 +84,14 @@ public abstract class Keys {
}
@Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<PreKey> get(String number, long deviceId) {
public Optional<UnstructuredPreKeyList> get(String number, long deviceId) {
PreKey preKey = retrieveFirst(number, deviceId);
if (preKey != null && !preKey.isLastResort()) {
removeKey(preKey.getId());
}
if (preKey != null) return Optional.of(preKey);
if (preKey != null) return Optional.of(new UnstructuredPreKeyList(preKey));
else return Optional.absent();
}

View File

@ -62,7 +62,7 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1234"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false))
.entity(new AccountAttributes("keykeykeykey", false, false, 2222))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
@ -76,7 +76,7 @@ public class AccountControllerTest extends ResourceTest {
ClientResponse response =
client().resource(String.format("/v1/accounts/code/%s", "1111"))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false))
.entity(new AccountAttributes("keykeykeykey", false, false, 3333))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);

View File

@ -83,7 +83,7 @@ public class DeviceControllerTest extends ResourceTest {
DeviceResponse response = client().resource("/v1/devices/5678901")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.entity(new AccountAttributes("keykeykeykey", false, true))
.entity(new AccountAttributes("keykeykeykey", false, true, 1234))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(DeviceResponse.class);

View File

@ -9,6 +9,8 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -24,9 +26,14 @@ public class KeyControllerTest extends ResourceTest {
private final String EXISTS_NUMBER = "+14152222222";
private final String NOT_EXISTS_NUMBER = "+14152222220";
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false);
private final Keys keys = mock(Keys.class);
private final int SAMPLE_REGISTRATION_ID = 999;
private final int SAMPLE_REGISTRATION_ID2 = 1002;
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4", false );
private final PreKey SAMPLE_KEY3 = new PreKey(3, EXISTS_NUMBER, 3, 334, "test5", "test6", false);
private final Keys keys = mock(Keys.class );
private final AccountsManager accounts = mock(AccountsManager.class);
@Override
protected void setUpResources() {
@ -35,17 +42,38 @@ public class KeyControllerTest extends ResourceTest {
RateLimiters rateLimiters = mock(RateLimiters.class);
RateLimiter rateLimiter = mock(RateLimiter.class );
Device sampleDevice = mock(Device.class );
Device sampleDevice2 = mock(Device.class);
Device sampleDevice3 = mock(Device.class);
Account existsAccount = mock(Account.class);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice.isActive()).thenReturn(true);
when(sampleDevice2.isActive()).thenReturn(true);
when(sampleDevice3.isActive()).thenReturn(false);
when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.isActive()).thenReturn(true);
when(accounts.get(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount));
when(accounts.get(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>absent());
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY));
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<PreKey>absent());
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(new UnstructuredPreKeyList(SAMPLE_KEY)));
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<UnstructuredPreKeyList>absent());
List<PreKey> allKeys = new LinkedList<>();
allKeys.add(SAMPLE_KEY);
allKeys.add(SAMPLE_KEY2);
allKeys.add(SAMPLE_KEY3);
when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(new UnstructuredPreKeyList(allKeys)));
addResource(new KeysController(rateLimiters, keys, null));
addResource(new KeysController(rateLimiters, keys, accounts, null));
}
@Test
@ -78,6 +106,7 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
@ -86,6 +115,7 @@ public class KeyControllerTest extends ResourceTest {
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(SAMPLE_KEY2.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);

View File

@ -53,12 +53,12 @@ public class MessageControllerTest extends ResourceTest {
addProvider(AuthHelper.getAuthenticator());
List<Device> singleDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111));
}};
List<Device> multiDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);

View File

@ -15,6 +15,7 @@ public class PreKeyTest {
@Test
public void serializeToJSON() throws Exception {
PreKey preKey = new PreKey(1, "+14152222222", 1, 1234, "test", "identityTest", false);
preKey.setRegistrationId(987);
assertThat("Basic Contact Serialization works",
asJson(preKey),

View File

@ -2,5 +2,6 @@
"deviceId" : 1,
"keyId" : 1234,
"publicKey" : "test",
"identityKey" : "identityTest"
"identityKey" : "identityTest",
"registrationId" : 987
}