diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 41f949082..baa9382e4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -52,6 +52,7 @@ import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeyState; +import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; @@ -64,9 +65,10 @@ import org.whispersystems.textsecuregcm.storage.KeysManager; @Tag(name = "Keys") public class KeysController { - private final RateLimiters rateLimiters; + private final RateLimiters rateLimiters; private final KeysManager keys; - private final AccountsManager accounts; + private final AccountsManager accounts; + private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange"); private static final String IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME = name(KeysController.class, "identityKeyChangeForbidden"); @@ -224,6 +226,9 @@ public class KeysController { ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; + compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), + keys.getEcSignedPreKey(identifier, device.getId())); + if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { final int registrationId = usePhoneNumberIdentity ? device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 9b099758c..ad94fddc7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Optional; import java.util.OptionalInt; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; @@ -221,6 +222,8 @@ class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); + when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); @@ -325,6 +328,7 @@ class KeysControllerTest { assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -348,6 +352,7 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -369,6 +374,7 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -388,6 +394,7 @@ class KeysControllerTest { assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); } @@ -409,6 +416,7 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_PNI, 1); verify(KEYS).takePQ(EXISTS_PNI, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); } @@ -430,6 +438,7 @@ class KeysControllerTest { assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); } @@ -465,6 +474,7 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -557,6 +567,9 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 2); verify(KEYS).takeEC(EXISTS_UUID, 4); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4); verifyNoMoreInteractions(KEYS); } @@ -622,6 +635,9 @@ class KeysControllerTest { verify(KEYS).takePQ(EXISTS_UUID, 2); verify(KEYS).takeEC(EXISTS_UUID, 4); verify(KEYS).takePQ(EXISTS_UUID, 4); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4); verifyNoMoreInteractions(KEYS); }