From 595cc55578783b05a103aafcbb2f5210e28f1bbc Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 18 Jan 2024 12:10:22 -0500 Subject: [PATCH] Retire the `returnPqKey` flag when fetching pre-keys --- .../controllers/KeysController.java | 10 ++------ .../controllers/KeysControllerTest.java | 23 +++++++++++++++---- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 0660449db..b499d1cce 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -17,7 +17,6 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -186,10 +185,6 @@ public class KeysController { @Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices") @PathParam("device_id") String deviceId, - @Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys") - @Schema(defaultValue="false") - @QueryParam("pq") boolean returnPqKey, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { @@ -229,9 +224,8 @@ public class KeysController { final CompletableFuture> signedEcPreKeyFuture = keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()); - final CompletableFuture> pqPreKeyFuture = returnPqKey - ? keysManager.takePQ(targetIdentifier.uuid(), device.getId()) - : CompletableFuture.completedFuture(Optional.empty()); + final CompletableFuture> pqPreKeyFuture = + keysManager.takePQ(targetIdentifier.uuid(), device.getId()); return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture) .thenAccept(ignored -> { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 008ac207b..1e76cbceb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -28,7 +28,6 @@ import io.dropwizard.testing.junit5.ResourceExtension; import java.time.Duration; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.UUID; @@ -110,6 +109,7 @@ class KeysControllerTest { private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair()); + private final KEMSignedPreKey SAMPLE_PQ_KEY4 = KeysHelper.signedKEMPreKey(7676, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair()); @@ -349,11 +349,12 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); - assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -415,11 +416,12 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); - assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -459,11 +461,12 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); - assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -555,6 +558,15 @@ class KeysControllerTest { when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY4))); + PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .request() @@ -597,6 +609,9 @@ class KeysControllerTest { verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);