Move account existence check to be before rate limit

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2017-02-16 17:05:48 -08:00
parent 571c7a8069
commit dd6c5292fd
7 changed files with 45 additions and 370 deletions

View File

@ -40,7 +40,6 @@ import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.controllers.FederationControllerV1;
import org.whispersystems.textsecuregcm.controllers.FederationControllerV2;
import org.whispersystems.textsecuregcm.controllers.KeepAliveController;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV1;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ProvisioningController;
@ -199,7 +198,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(pushSender);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysControllerV1 keysControllerV1 = new KeysControllerV1(rateLimiters, keys, accountsManager, federatedClientManager);
KeysControllerV2 keysControllerV2 = new KeysControllerV2(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager);
@ -216,12 +214,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, messagesManager, new TimeProvider(), authorizationKey, turnTokenGenerator, config.getTestDevices()));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, rateLimiters));
environment.jersey().register(new DirectoryController(rateLimiters, directory));
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController));
environment.jersey().register(new FederationControllerV2(accountsManager, attachmentController, messageController, keysControllerV2));
environment.jersey().register(new ReceiptController(receiptSender));
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));
environment.jersey().register(attachmentController);
environment.jersey().register(keysControllerV1);
environment.jersey().register(keysControllerV2);
environment.jersey().register(messageController);

View File

@ -25,8 +25,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
@ -53,15 +51,11 @@ public class FederationControllerV1 extends FederationController {
private static final int ACCOUNT_CHUNK_SIZE = 10000;
private final KeysControllerV1 keysControllerV1;
public FederationControllerV1(AccountsManager accounts,
AttachmentController attachmentController,
MessageController messageController,
KeysControllerV1 keysControllerV1)
MessageController messageController)
{
super(accounts, attachmentController, messageController);
this.keysControllerV1 = keysControllerV1;
}
@Timed
@ -76,41 +70,6 @@ public class FederationControllerV1 extends FederationController {
attachmentId, Optional.<String>absent());
}
@Timed
@GET
@Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyV1> getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
throws IOException
{
try {
return keysControllerV1.get(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV1> getKeysV1(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysControllerV1.getDeviceKey(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@PUT
@Path("/messages/{source}/{sourceDeviceId}/{destination}")

View File

@ -68,32 +68,19 @@ public class KeysController {
return new PreKeyCount(count);
}
protected TargetKeys getLocalKeys(String number, String deviceIdSelector)
protected Optional<List<KeyRecord>> getLocalKeys(Account destination, String deviceIdSelector)
throws NoSuchUserException
{
Optional<Account> destination = accounts.get(number);
if (!destination.isPresent() || !destination.get().isActive()) {
throw new NoSuchUserException("Target account is inactive");
}
try {
if (deviceIdSelector.equals("*")) {
Optional<List<KeyRecord>> preKeys = keys.get(number);
return new TargetKeys(destination.get(), preKeys);
return keys.get(destination.getNumber());
}
long deviceId = Long.parseLong(deviceIdSelector);
Optional<Device> targetDevice = destination.get().getDevice(deviceId);
if (!targetDevice.isPresent() || !targetDevice.get().isActive()) {
throw new NoSuchUserException("Target device is inactive.");
}
long deviceId = Long.parseLong(deviceIdSelector);
for (int i=0;i<20;i++) {
try {
Optional<List<KeyRecord>> preKeys = keys.get(number, deviceId);
return new TargetKeys(destination.get(), preKeys);
return keys.get(destination.getNumber(), deviceId);
} catch (UnableToExecuteStatementException e) {
logger.info(e.getMessage());
}
@ -104,24 +91,4 @@ public class KeysController {
throw new WebApplicationException(Response.status(422).build());
}
}
public static class TargetKeys {
private final Account destination;
private final Optional<List<KeyRecord>> keys;
public TargetKeys(Account destination, Optional<List<KeyRecord>> keys) {
this.destination = destination;
this.keys = keys;
}
public Optional<List<KeyRecord>> getKeys() {
return keys;
}
public Account getDestination() {
return destination;
}
}
}

View File

@ -1,136 +0,0 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.Auth;
@Path("/v1/keys")
public class KeysControllerV1 extends KeysController {
private final Logger logger = LoggerFactory.getLogger(KeysControllerV1.class);
public KeysControllerV1(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager)
{
super(rateLimiters, keys, accounts, federatedClientManager);
}
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Account account, @Valid PreKeyStateV1 preKeys) {
Device device = account.getAuthenticatedDevice().get();
String identityKey = preKeys.getLastResortKey().getIdentityKey();
if (!identityKey.equals(account.getIdentityKey())) {
account.setIdentityKey(identityKey);
accounts.update(account);
}
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
}
@Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV1> getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
try {
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (relay.isPresent()) {
return federatedClientManager.getClient(relay.get()).getKeysV1(number, deviceId);
}
TargetKeys targetKeys = getLocalKeys(number, deviceId);
if (!targetKeys.getKeys().isPresent()) {
return Optional.absent();
}
List<PreKeyV1> preKeys = new LinkedList<>();
Account destination = targetKeys.getDestination();
for (KeyRecord record : targetKeys.getKeys().get()) {
Optional<Device> device = destination.getDevice(record.getDeviceId());
if (device.isPresent() && device.get().isActive()) {
preKeys.add(new PreKeyV1(record.getDeviceId(), record.getKeyId(),
record.getPublicKey(), destination.getIdentityKey(),
device.get().getRegistrationId()));
}
}
if (preKeys.isEmpty()) return Optional.absent();
else return Optional.of(new PreKeyResponseV1(preKeys));
} catch (NoSuchPeerException | NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
}
}
@Timed
@GET
@Path("/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyV1> get(@Auth Account account,
@PathParam("number") String number,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
Optional<PreKeyResponseV1> results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
if (results.isPresent()) return Optional.of(results.get().getKeys().get(0));
else return Optional.absent();
}
}

View File

@ -94,25 +94,26 @@ public class KeysControllerV2 extends KeysController {
throws RateLimitExceededException
{
try {
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (relay.isPresent()) {
return federatedClientManager.getClient(relay.get()).getKeysV2(number, deviceId);
}
TargetKeys targetKeys = getLocalKeys(number, deviceId);
Account destination = targetKeys.getDestination();
Account target = getAccount(number, deviceId);
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
Optional<List<KeyRecord>> targetKeys = getLocalKeys(target, deviceId);
List<PreKeyResponseItemV2> devices = new LinkedList<>();
for (Device device : destination.getDevices()) {
for (Device device : target.getDevices()) {
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey();
PreKeyV2 preKey = null;
if (targetKeys.getKeys().isPresent()) {
for (KeyRecord keyRecord : targetKeys.getKeys().get()) {
if (targetKeys.isPresent()) {
for (KeyRecord keyRecord : targetKeys.get()) {
if (keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKeyV2(keyRecord.getKeyId(), keyRecord.getPublicKey());
}
@ -126,7 +127,7 @@ public class KeysControllerV2 extends KeysController {
}
if (devices.isEmpty()) return Optional.absent();
else return Optional.of(new PreKeyResponseV2(destination.getIdentityKey(), devices));
else return Optional.of(new PreKeyResponseV2(target.getIdentityKey(), devices));
} catch (NoSuchPeerException | NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
}
@ -153,4 +154,30 @@ public class KeysControllerV2 extends KeysController {
if (signedPreKey != null) return Optional.of(signedPreKey);
else return Optional.absent();
}
private Account getAccount(String number, String deviceSelector)
throws NoSuchUserException
{
try {
Optional<Account> account = accounts.get(number);
if (!account.isPresent() || !account.get().isActive()) {
throw new NoSuchUserException("No active account");
}
if (!deviceSelector.equals("*")) {
long deviceId = Long.parseLong(deviceSelector);
Optional<Device> targetDevice = account.get().getDevice(deviceId);
if (!targetDevice.isPresent() || !targetDevice.get().isActive()) {
throw new NoSuchUserException("No active device");
}
}
return account.get();
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
}
}

View File

@ -70,7 +70,7 @@ public class FederatedControllerTest {
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new FederationControllerV1(accountsManager, null, messageController, null))
.addResource(new FederationControllerV1(accountsManager, null, messageController))
.addResource(new FederationControllerV2(accountsManager, null, messageController, keysControllerV2))
.build();

View File

@ -7,14 +7,10 @@ import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.dropwizard.simpleauth.AuthValueFactoryProvider;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV1;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV1;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV2;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV2;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
@ -69,7 +65,6 @@ public class KeyControllerTest {
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new KeysControllerV1(rateLimiters, keys, accounts, null))
.addResource(new KeysControllerV2(rateLimiters, keys, accounts, null))
.build();
@ -112,6 +107,7 @@ public class KeyControllerTest {
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isActive()).thenReturn(true);
when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey");
when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER);
when(accounts.get(EXISTS_NUMBER)).thenReturn(Optional.of(existsAccount));
when(accounts.get(NOT_EXISTS_NUMBER)).thenReturn(Optional.<Account>absent());
@ -137,20 +133,6 @@ public class KeyControllerTest {
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@Test
public void validKeyStatusTestV1() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v1/keys")
.request()
.header("Authorization",
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class);
assertThat(result.getCount() == 4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
}
@Test
public void validKeyStatusTestV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
@ -191,22 +173,6 @@ public class KeyControllerTest {
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
}
@Test
public void validLegacyRequestTest() throws Exception {
PreKeyV1 result = resources.getJerseyTest()
.target(String.format("/v1/keys/%s", EXISTS_NUMBER))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyV1.class);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys);
}
@Test
public void validSingleRequestTestV2() throws Exception {
PreKeyResponseV2 result = resources.getJerseyTest()
@ -225,40 +191,6 @@ public class KeyControllerTest {
verifyNoMoreInteractions(keys);
}
@Test
public void validMultiRequestTestV1() throws Exception {
PreKeyResponseV1 results = resources.getJerseyTest()
.target(String.format("/v1/keys/%s/*", EXISTS_NUMBER))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponseV1.class);
assertThat(results.getKeys().size()).isEqualTo(3);
PreKeyV1 result = results.getKeys().get(0);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
result = results.getKeys().get(1);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID2);
result = results.getKeys().get(2);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID4);
verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys);
}
@Test
public void validMultiRequestTestV2() throws Exception {
PreKeyResponseV2 results = resources.getJerseyTest()
@ -309,18 +241,6 @@ public class KeyControllerTest {
verifyNoMoreInteractions(keys);
}
@Test
public void invalidRequestTestV1() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get();
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(404);
}
@Test
public void invalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
@ -343,26 +263,6 @@ public class KeyControllerTest {
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(404);
}
@Test
public void unauthorizedRequestTestV1() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD))
.get();
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(401);
response =
resources.getJerseyTest()
.target(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
.request()
.get();
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(401);
}
@Test
public void unauthorizedRequestTestV2() throws Exception {
Response response =
@ -383,45 +283,6 @@ public class KeyControllerTest {
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(401);
}
@Test
public void putKeysTestV1() throws Exception {
final PreKeyV1 newKey = new PreKeyV1(1L, 31337, "foobar", "foobarbaz");
final PreKeyV1 lastResortKey = new PreKeyV1(1L, 0xFFFFFF, "fooz", "foobarbaz");
List<PreKeyV1> preKeys = new LinkedList<PreKeyV1>() {{
add(newKey);
}};
PreKeyStateV1 preKeyList = new PreKeyStateV1();
preKeyList.setKeys(preKeys);
preKeyList.setLastResortKey(lastResortKey);
Response response =
resources.getJerseyTest()
.target("/v1/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyList, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class );
ArgumentCaptor<PreKeyV1> lastResortCaptor = ArgumentCaptor.forClass(PreKeyV1.class);
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture(), lastResortCaptor.capture());
List<PreKeyV1> capturedList = listCaptor.getValue();
assertThat(capturedList.size() == 1);
assertThat(capturedList.get(0).getIdentityKey().equals("foobarbaz"));
assertThat(capturedList.get(0).getKeyId() == 31337);
assertThat(capturedList.get(0).getPublicKey().equals("foobar"));
assertThat(lastResortCaptor.getValue().getPublicKey().equals("fooz"));
assertThat(lastResortCaptor.getValue().getIdentityKey().equals("foobarbaz"));
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("foobarbaz"));
verify(accounts).update(AuthHelper.VALID_ACCOUNT);
}
@Test
public void putKeysTestV2() throws Exception {
final PreKeyV2 preKey = new PreKeyV2(31337, "foobar");