From d18f576239583420935b7e6a34b0a87dfe09846e Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 20 Dec 2023 17:17:40 -0500 Subject: [PATCH] Revert "Revert "Treat the stand-alone signed pre-keys table as the source of truth for signed pre-keys"" This reverts commit 3f9edfe597d873b7b75e84b51715e49963a9b805. --- .../controllers/KeysController.java | 43 +++---------- .../textsecuregcm/grpc/KeysGrpcHelper.java | 62 +++++++++---------- .../textsecuregcm/storage/Device.java | 5 ++ .../controllers/KeysControllerTest.java | 55 +++++++++------- .../grpc/KeysAnonymousGrpcServiceTest.java | 11 +++- .../grpc/KeysGrpcServiceTest.java | 4 +- 6 files changed, 84 insertions(+), 96 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index dfa763c94..47aeeec34 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -38,8 +38,6 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.signal.libsignal.protocol.IdentityKey; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; @@ -51,7 +49,6 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.entities.SetKeysRequest; import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -74,14 +71,11 @@ public class KeysController { private final RateLimiters rateLimiters; private final KeysManager keys; private final AccountsManager accounts; - private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys"); private static final CompletableFuture[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; - private static final Logger logger = LoggerFactory.getLogger(KeysController.class); - public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) { this.rateLimiters = rateLimiters; this.keys = keys; @@ -247,35 +241,13 @@ public class KeysController { .increment(); final List responseItems = Flux.fromIterable(parseDeviceId(deviceId, target)) - .flatMap(device -> { - final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(targetIdentifier.identityType()); - final CompletableFuture> signedPreKeyFuture = - keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()); - - compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(ecSignedPreKey), signedPreKeyFuture); - - signedPreKeyFuture.whenComplete((maybeSignedPreKey, throwable) -> { - if (throwable == null) { - if (!Optional.ofNullable(ecSignedPreKey).equals(maybeSignedPreKey)) { - logger.warn("Signed pre-keys do not match for {}, device {}. From device: {}; from table: {}", - targetIdentifier, - deviceId, - Optional.ofNullable(ecSignedPreKey).map(ECSignedPreKey::keyId), - maybeSignedPreKey.map(ECSignedPreKey::keyId)); - } - } else { - logger.error("Failed to get signed pre-key for {}, device {}", targetIdentifier, deviceId, throwable); - } - }); - - return Mono.zip( - Mono.just(device), - Mono.just(Optional.ofNullable(ecSignedPreKey)), - Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), - Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) - : CompletableFuture.>completedFuture(Optional.empty())) - ); - }).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) + .flatMap(device -> Mono.zip( + Mono.just(device), + Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())), + Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), + Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) + : CompletableFuture.>completedFuture(Optional.empty())) + )).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) .map(deviceAndKeys -> { final Device device = deviceAndKeys.getT1(); final int registrationId = switch (targetIdentifier.identityType()) { @@ -296,6 +268,7 @@ public class KeysController { if (responseItems.isEmpty()) { throw new WebApplicationException(Response.Status.NOT_FOUND); } + return new PreKeyResponse(identityKey, responseItems); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index 5f80be141..249e234c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -41,41 +41,37 @@ class KeysGrpcHelper { return devices .filter(Device::isEnabled) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) - .flatMap(device -> { - final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(identityType); + .flatMap(device -> Flux.merge( + Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), + Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())), + Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) + .flatMap(Mono::justOrEmpty) + .reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> { + if (preKey instanceof ECPreKey ecPreKey) { + builder.setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build()); + } else if (preKey instanceof ECSignedPreKey ecSignedPreKey) { + builder.setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()); + } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { + builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build()); + } else { + throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); + } - final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder() - .setEcSignedPreKey(EcSignedPreKey.newBuilder() - .setKeyId(ecSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) - .build()); - - return Flux.merge( - Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), - Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId()))) - .flatMap(Mono::justOrEmpty) - .reduce(preKeyBundleBuilder, (builder, preKey) -> { - if (preKey instanceof ECPreKey ecPreKey) { - builder.setEcOneTimePreKey(EcPreKey.newBuilder() - .setKeyId(ecPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) - .build()); - } else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) { - preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() - .setKeyId(kemSignedPreKey.keyId()) - .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) - .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) - .build()); - } else { - throw new AssertionError("Unexpected pre-key type: " + preKey.getClass()); - } - - return builder; - }) - // Cast device IDs to `int` to match data types in the response object’s protobuf definition - .map(builder -> Tuples.of((int) device.getId(), builder.build())); + return builder; }) + // Cast device IDs to `int` to match data types in the response object’s protobuf definition + .map(builder -> Tuples.of((int) device.getId(), builder.build()))) .collectMap(Tuple2::getT1, Tuple2::getT2) .map(preKeyBundles -> GetPreKeysResponse.newBuilder() .setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize())) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 626458766..b32f04899 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -11,6 +11,7 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.time.Duration; import java.util.List; import java.util.OptionalInt; +import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; @@ -246,6 +247,10 @@ public class Device { this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; } + /** + * @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead + */ + @Deprecated public ECSignedPreKey getSignedPreKey(final IdentityType identityType) { return switch (identityType) { case ACI -> signedPreKey; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 062de0091..6c2890e73 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -199,14 +199,6 @@ class KeysControllerTest { when(sampleDevice2.isEnabled()).thenReturn(true); when(sampleDevice3.isEnabled()).thenReturn(false); when(sampleDevice4.isEnabled()).thenReturn(true); - when(sampleDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY); - when(sampleDevice2.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY2); - when(sampleDevice3.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY3); - when(sampleDevice4.getSignedPreKey(IdentityType.ACI)).thenReturn(null); - when(sampleDevice.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY); - when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2); - when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3); - when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null); when(sampleDevice.getId()).thenReturn(sampleDeviceId); when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); @@ -257,6 +249,24 @@ class KeysControllerTest { when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDeviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY))); + + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice2Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY2))); + + when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice3Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY3))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDeviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice2Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY2))); + + when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice3Id)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn( @@ -269,9 +279,13 @@ class KeysControllerTest { when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); - when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY); - when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(null); + + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_SIGNED_KEY))); + + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_PNI_SIGNED_KEY))); } @AfterEach @@ -350,8 +364,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -374,8 +387,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -397,8 +409,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); @@ -419,8 +430,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -441,8 +451,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -465,8 +474,7 @@ class KeysControllerTest { assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); @@ -501,8 +509,7 @@ class KeysControllerTest { assertThat(result.getDevicesCount()).isEqualTo(1); assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); - assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); + assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java index 3ede64437..f5183c9b2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -96,9 +96,14 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest when(keysManager.takeEC(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + ecSignedPreKeys.forEach((deviceId, preKey) -> when(keysManager.getEcSignedPreKey(identifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId)) .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));