diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 1fdad9571..e350bbf0b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -49,7 +49,6 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.entities.SetKeysRequest; import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -67,7 +66,6 @@ public class KeysController { private final RateLimiters rateLimiters; private final KeysManager keys; private final AccountsManager accounts; - private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); private static final CompletableFuture[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; @@ -237,37 +235,33 @@ public class KeysController { final List responseItems = new ArrayList<>(devices.size()); final List> tasks = devices.stream().map(device -> { + final CompletableFuture> unsignedEcPreKeyFuture = + keys.takeEC(targetIdentifier.uuid(), device.getId()); - ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType()); + final CompletableFuture> signedEcPreKeyFuture = + keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()); - final CompletableFuture> unsignedEcPreKeyFuture = keys.takeEC(targetIdentifier.uuid(), - device.getId()); final CompletableFuture> pqPreKeyFuture = returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) : CompletableFuture.completedFuture(Optional.empty()); - return pqPreKeyFuture.thenCombine(unsignedEcPreKeyFuture, - (maybePqPreKey, maybeUnsignedEcPreKey) -> { + return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture) + .thenAccept(ignored -> { + final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null); + final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null); + final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null); - KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null); - ECPreKey unsignedECPreKey = unsignedEcPreKeyFuture.join().orElse(null); - - compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), - keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())); - - if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { + if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) { final int registrationId = switch (targetIdentifier.identityType()) { case ACI -> device.getRegistrationId(); case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); }; responseItems.add( - new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, + new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey, pqPreKey)); } - - return null; - }).thenRun(Util.NOOP); + }); }) .toList(); @@ -278,6 +272,7 @@ public class KeysController { if (responseItems.isEmpty()) { throw new WebApplicationException(Response.Status.NOT_FOUND); } + return new PreKeyResponse(identityKey, responseItems); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 5f80be141..249e234c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -41,41 +41,37 @@ class KeysGrpcHelper { return devices .filter(Device::isEnabled) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) - .flatMap(device -> { - final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(identityType); + .flatMap(device -> Flux.merge( + Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), + Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())), + Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) + .flatMap(Mono::justOrEmpty) + .reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> { + if (preKey instanceof ECPreKey ecPreKey) { + builder.setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build()); + } else if (preKey instanceof ECSignedPreKey ecSignedPreKey) { + builder.setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()); + } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { + builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build()); + } else { + throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); + } - final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder() - .setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(ecSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) - .build()); - - return Flux.merge( - Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) - .flatMap(Mono::justOrEmpty) - .reduce(preKeyBundleBuilder, (builder, preKey) -> { - if (preKey instanceof ECPreKey ecPreKey) { - builder.setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) - .build()); - } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { - preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(kemSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) - .build()); - } else { - throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); - } - - return builder; - }) - // Cast device IDs to `int` to match data types in the response object’s protobuf definition - .map(builder -> Tuples.of((int) device.getId(), builder.build())); + return builder; }) + // Cast device IDs to `int` to match data types in the response object’s protobuf definition + .map(builder -> Tuples.of((int) device.getId(), builder.build()))) .collectMap(Tuple2::getT1, Tuple2::getT2) .map(preKeyBundles -> GetPreKeysResponse.newBuilder() .setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 790bce398..0ceab834c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.util.List; import java.util.OptionalInt; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -238,6 +239,10 @@ public class Device { this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; } + /** + * @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead + */ + @Deprecated public ECSignedPreKey getSignedPreKey(final IdentityType identityType) { return switch (identityType) { case ACI -> signedPreKey; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 7eac5b2db..64e337372 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -202,14 +202,6 @@ class KeysControllerTest { when(sampleDevice2.isEnabled()).thenReturn(true); when(sampleDevice3.isEnabled()).thenReturn(false); when(sampleDevice4.isEnabled()).thenReturn(true); - when(sampleDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY); - when(sampleDevice2.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY2); - when(sampleDevice3.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY3); - when(sampleDevice4.getSignedPreKey(IdentityType.ACI)).thenReturn(null); - when(sampleDevice.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY); - when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2); - when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3); - when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null); when(sampleDevice.getId()).thenReturn(sampleDeviceId); when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); @@ -260,6 +252,24 @@ class KeysControllerTest { when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDeviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY))); + + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice2Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY2))); + + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice3Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY3))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDeviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice2Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY2))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice3Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn( @@ -272,9 +282,13 @@ class KeysControllerTest { when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); - when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY); - when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(null); + + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_SIGNED_KEY))); + + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_PNI_SIGNED_KEY))); } @AfterEach @@ -365,8 +379,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -389,8 +402,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -412,8 +424,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -434,8 +445,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -456,8 +466,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -480,8 +489,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -516,8 +524,7 @@ class KeysControllerTest { assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index 3ede64437..f5183c9b2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -96,9 +96,14 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest when(keysManager.takeEC(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + ecSignedPreKeys.forEach((deviceId, preKey) -> when(keysManager.getEcSignedPreKey(identifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));