diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index ef436ff8e..2af1c748e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -420,7 +420,7 @@ public class WhisperServerService extends Application 0) { count = count - 1; @@ -98,7 +104,7 @@ public class KeysController { } } - keys.store(account, device.getId(), preKeys.getPreKeys()); + getPreKeyStoreForProducer(account).store(account, device.getId(), preKeys.getPreKeys()); } @Timed @@ -179,14 +185,22 @@ public class KeysController { private List getLocalKeys(Account destination, String deviceIdSelector) { try { if (deviceIdSelector.equals("*")) { - return keys.take(destination); + return getPreKeyStoreForConsumer(destination).take(destination); } long deviceId = Long.parseLong(deviceIdSelector); - return keys.take(destination, deviceId); + return getPreKeyStoreForConsumer(destination).take(destination, deviceId); } catch (NumberFormatException e) { throw new WebApplicationException(Response.status(422).build()); } } + + private PreKeyStore getPreKeyStoreForProducer(final Account account) { + return experimentEnrollmentManager.isEnrolled(account.getUuid(), DYNAMODB_PRODUCER_EXPERIMENT) ? keysDynamoDb : keys; + } + + private PreKeyStore getPreKeyStoreForConsumer(final Account account) { + return experimentEnrollmentManager.isEnrolled(account.getUuid(), DYNAMODB_CONSUMER_EXPERIMENT) ? keysDynamoDb : keys; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index bb2ec3d7a..af2c64812 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -21,6 +21,7 @@ import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; @@ -29,6 +30,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeyRecord; import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import javax.ws.rs.client.Entity; @@ -69,10 +71,12 @@ public class KeysControllerTest { private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" ); private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid"); - private final Keys keys = mock(Keys.class ); - private final AccountsManager accounts = mock(AccountsManager.class); - private final DirectoryQueue directoryQueue = mock(DirectoryQueue.class); - private final Account existsAccount = mock(Account.class ); + private final Keys keys = mock(Keys.class ); + private final KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class ); + private final AccountsManager accounts = mock(AccountsManager.class ); + private final DirectoryQueue directoryQueue = mock(DirectoryQueue.class ); + private final Account existsAccount = mock(Account.class ); + private final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); private RateLimiters rateLimiters = mock(RateLimiters.class); private RateLimiter rateLimiter = mock(RateLimiter.class ); @@ -82,7 +86,7 @@ public class KeysControllerTest { .addProvider(AuthHelper.getAuthFilter()) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new KeysController(rateLimiters, keys, accounts, directoryQueue)) + .addResource(new KeysController(rateLimiters, keys, keysDynamoDb, accounts, directoryQueue, experimentEnrollmentManager)) .build(); @Before