diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 17191e79e..bfe530a68 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -50,6 +50,7 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.entities.SetKeysRequest; import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -71,6 +72,7 @@ public class KeysController { private final RateLimiters rateLimiters; private final KeysManager keys; private final AccountsManager accounts; + private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys"); @@ -243,13 +245,20 @@ public class KeysController { .increment(); final List responseItems = Flux.fromIterable(parseDeviceId(deviceId, target)) - .flatMap(device -> Mono.zip( - Mono.just(device), - Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())), - Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), - Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) - : CompletableFuture.>completedFuture(Optional.empty())) - )).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) + .flatMap(device -> { + final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(targetIdentifier.identityType()); + + compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(ecSignedPreKey), + keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())); + + return Mono.zip( + Mono.just(device), + Mono.just(Optional.ofNullable(ecSignedPreKey)), + Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), + Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) + : CompletableFuture.>completedFuture(Optional.empty())) + ); + }).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) .map(deviceAndKeys -> { final Device device = deviceAndKeys.getT1(); final int registrationId = switch (targetIdentifier.identityType()) { @@ -270,7 +279,6 @@ public class KeysController { if (responseItems.isEmpty()) { throw new WebApplicationException(Response.Status.NOT_FOUND); } - return new PreKeyResponse(identityKey, responseItems); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 249e234c2..5f80be141 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -41,37 +41,41 @@ class KeysGrpcHelper { return devices .filter(Device::isEnabled) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) - .flatMap(device -> Flux.merge( - Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) - .flatMap(Mono::justOrEmpty) - .reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> { - if (preKey instanceof ECPreKey ecPreKey) { - builder.setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) - .build()); - } else if (preKey instanceof ECSignedPreKey ecSignedPreKey) { - builder.setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(ecSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) - .build()); - } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { - builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(kemSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) - .build()); - } else { - throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); - } + .flatMap(device -> { + final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(identityType); - return builder; + final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()); + + return Flux.merge( + Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), + Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) + .flatMap(Mono::justOrEmpty) + .reduce(preKeyBundleBuilder, (builder, preKey) -> { + if (preKey instanceof ECPreKey ecPreKey) { + builder.setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build()); + } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { + preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build()); + } else { + throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); + } + + return builder; + }) + // Cast device IDs to `int` to match data types in the response object’s protobuf definition + .map(builder -> Tuples.of((int) device.getId(), builder.build())); }) - // Cast device IDs to `int` to match data types in the response object’s protobuf definition - .map(builder -> Tuples.of((int) device.getId(), builder.build()))) .collectMap(Tuple2::getT1, Tuple2::getT2) .map(preKeyBundles -> GetPreKeysResponse.newBuilder() .setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 17ad0abba..d01642403 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -10,7 +10,6 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.util.List; import java.util.OptionalInt; -import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -239,10 +238,6 @@ public class Device { this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; } - /** - * @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead - */ - @Deprecated public ECSignedPreKey getSignedPreKey(final IdentityType identityType) { return switch (identityType) { case ACI -> signedPreKey; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 64e337372..7eac5b2db 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -202,6 +202,14 @@ class KeysControllerTest { when(sampleDevice2.isEnabled()).thenReturn(true); when(sampleDevice3.isEnabled()).thenReturn(false); when(sampleDevice4.isEnabled()).thenReturn(true); + when(sampleDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY); + when(sampleDevice2.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY2); + when(sampleDevice3.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY3); + when(sampleDevice4.getSignedPreKey(IdentityType.ACI)).thenReturn(null); + when(sampleDevice.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY); + when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2); + when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3); + when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null); when(sampleDevice.getId()).thenReturn(sampleDeviceId); when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); @@ -252,24 +260,6 @@ class KeysControllerTest { when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDeviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY))); - - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice2Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY2))); - - when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice3Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY3))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDeviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice2Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY2))); - - when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice3Id)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY3))); - when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn( @@ -282,13 +272,9 @@ class KeysControllerTest { when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); + when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY); + when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(null); - - when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId())) - .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_SIGNED_KEY))); - - when(KEYS.getEcSignedPreKey(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId())) - .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_PNI_SIGNED_KEY))); } @AfterEach @@ -379,7 +365,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -402,7 +389,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -424,7 +412,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -445,7 +434,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -466,7 +456,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -489,7 +480,8 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -524,7 +516,8 @@ class KeysControllerTest { assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); - assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index f5183c9b2..3ede64437 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -96,14 +96,9 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest when(keysManager.takeEC(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); - ecSignedPreKeys.forEach((deviceId, preKey) -> when(keysManager.getEcSignedPreKey(identifier, deviceId)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); - kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));