From 5627209fdd44b4d235c31bb9b1ada8ecfcfb2290 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Thu, 20 Jul 2023 11:10:26 -0400 Subject: [PATCH] Add a gRPC service for working with pre-keys --- pom.xml | 16 + service/pom.xml | 5 + .../textsecuregcm/WhisperServerService.java | 14 +- .../auth/UnidentifiedAccessUtil.java | 32 + .../auth/grpc/AuthenticatedDevice.java | 11 + .../auth/grpc/AuthenticationUtil.java | 28 +- ...icCredentialAuthenticationInterceptor.java | 5 +- .../auth/grpc/NotAuthenticatedException.java | 13 - .../textsecuregcm/grpc/IdentityType.java | 21 + .../grpc/KeysAnonymousGrpcService.java | 40 ++ .../textsecuregcm/grpc/KeysGrpcHelper.java | 107 +++ .../textsecuregcm/grpc/KeysGrpcService.java | 307 ++++++++ .../textsecuregcm/grpc/RateLimitUtil.java | 45 ++ .../textsecuregcm/limits/RateLimiter.java | 5 + .../textsecuregcm/storage/KeysManager.java | 8 + .../util/NoStackTraceRuntimeException.java | 29 + .../textsecuregcm/util/UUIDUtil.java | 9 + .../main/proto/org/signal/chat/common.proto | 71 ++ .../src/main/proto/org/signal/chat/keys.proto | 263 +++++++ .../auth/UnidentifiedAccessUtilTest.java | 52 ++ .../grpc/MockAuthenticationInterceptor.java | 46 ++ .../grpc/GrpcServerExtension.java | 119 +++ .../grpc/KeysAnonymousGrpcServiceTest.java | 211 ++++++ .../grpc/KeysGrpcServiceTest.java | 678 ++++++++++++++++++ 24 files changed, 2112 insertions(+), 23 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java create mode 100644 service/src/main/proto/org/signal/chat/common.proto create mode 100644 service/src/main/proto/org/signal/chat/keys.proto create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java diff --git a/pom.xml b/pom.xml index a299fa84c..7610e7e4d 100644 --- a/pom.xml +++ b/pom.xml @@ -62,6 +62,7 @@ 1.2.0 3.21.7 0.15.2 + 1.2.4 1.7.0 3.1.0 1.7.30 @@ -124,6 +125,11 @@ pom import + + com.salesforce.servicelibs + reactor-grpc-stub + ${reactive.grpc.version} + io.github.resilience4j resilience4j-bom @@ -398,6 +404,16 @@ com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} grpc-java io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} + + + + reactor-grpc + com.salesforce.servicelibs + reactor-grpc + ${reactive.grpc.version} + com.salesforce.reactorgrpc.ReactorGrpcGenerator + + diff --git a/service/pom.xml b/service/pom.xml index a6a7b35c7..9ea5159e7 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -286,6 +286,11 @@ jackson-jaxrs-json-provider + + com.salesforce.servicelibs + reactor-grpc-stub + + software.amazon.awssdk sts diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 4a3bcf50d..2939bab4c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -24,6 +24,7 @@ import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.setup.Bootstrap; import io.dropwizard.setup.Environment; import io.grpc.ServerBuilder; +import io.grpc.ServerInterceptors; import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder; import io.lettuce.core.metrics.MicrometerOptions; import io.lettuce.core.resource.ClientResources; @@ -64,6 +65,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; @@ -72,6 +74,7 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; +import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor; import org.whispersystems.textsecuregcm.badges.ConfiguredProfileBadgeConverter; import org.whispersystems.textsecuregcm.badges.ResourceBundleLevelTranslator; import org.whispersystems.textsecuregcm.captcha.CaptchaChecker; @@ -115,6 +118,8 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; +import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; +import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService; import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -401,6 +406,7 @@ public class WhisperServerService extends Application disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator( disabledPermittedAccountAuthenticator).buildAuthFilter(); + final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor = + new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager)); + final ServerBuilder grpcServer = ServerBuilder.forPort(config.getGrpcPort()) - .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry)); /* TODO: specialize metrics with user-agent platform */ + // TODO: specialize metrics with user-agent platform + .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry)) + .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor)) + .addService(new KeysAnonymousGrpcService(accountsManager, keys)); RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager); environment.servlets() diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java new file mode 100644 index 000000000..32283170a --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import org.whispersystems.textsecuregcm.storage.Account; +import java.security.MessageDigest; + +public class UnidentifiedAccessUtil { + + private UnidentifiedAccessUtil() { + } + + /** + * Checks whether an action (e.g. sending a message or retrieving pre-keys) may be taken on the target account by an + * actor presenting the given unidentified access key. + * + * @param targetAccount the account on which an actor wishes to take an action + * @param unidentifiedAccessKey the unidentified access key presented by the actor + * + * @return {@code true} if an actor presenting the given unidentified access key has permission to take an action on + * the target account or {@code false} otherwise + */ + public static boolean checkUnidentifiedAccess(final Account targetAccount, final byte[] unidentifiedAccessKey) { + return targetAccount.isUnrestrictedUnidentifiedAccess() + || targetAccount.getUnidentifiedAccessKey() + .map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey)) + .orElse(false); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java new file mode 100644 index 000000000..906056986 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java @@ -0,0 +1,11 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth.grpc; + +import java.util.UUID; + +public record AuthenticatedDevice(UUID accountIdentifier, long deviceId) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java index b68f0cf4e..a449b2092 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java @@ -6,8 +6,12 @@ package org.whispersystems.textsecuregcm.auth.grpc; import io.grpc.Context; -import java.util.Optional; +import io.grpc.Status; import java.util.UUID; +import javax.annotation.Nullable; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; /** * Provides utility methods for working with authentication in the context of gRPC calls. @@ -17,11 +21,23 @@ public class AuthenticationUtil { static final Context.Key CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci"); static final Context.Key CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id"); - public static Optional getAuthenticatedAccountIdentifier() { - return Optional.ofNullable(CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get()); - } + /** + * Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if + * no authenticated account/device is available. + * + * @return the account/device authenticated in the current gRPC context + * + * @throws io.grpc.StatusRuntimeException with a status of {@code UNAUTHENTICATED} if no authenticated account/device + * could be retrieved from the current gRPC context + */ + public static AuthenticatedDevice requireAuthenticatedDevice() { + @Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get(); + @Nullable final Long deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get(); - public static Optional getAuthenticatedDeviceIdentifier() { - return Optional.ofNullable(CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get()); + if (accountIdentifier != null && deviceId != null) { + return new AuthenticatedDevice(accountIdentifier, deviceId); + } + + throw Status.UNAUTHENTICATED.asRuntimeException(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java index 307c00092..73a1df2b0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java @@ -24,9 +24,8 @@ import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; * Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the * {@code x-signal-basic-auth-credentials} call header. *

- * Downstream services can retrieve the identity of the authenticated caller using - * {@link AuthenticationUtil#getAuthenticatedAccountIdentifier()} and - * {@link AuthenticationUtil#getAuthenticatedDeviceIdentifier()}. + * Downstream services can retrieve the identity of the authenticated caller using methods in + * {@link AuthenticationUtil}. *

* Note that this authentication, while fully functional, is intended only for development and testing purposes and is * intended to be replaced with a more robust and efficient strategy before widespread client adoption. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java deleted file mode 100644 index 67e596476..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth.grpc; - -/** - * Indicates that a caller tried to get information about the authenticated gRPC caller, but no caller has been - * authenticated. - */ -public class NotAuthenticatedException extends Exception { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java new file mode 100644 index 000000000..d558348bf --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java @@ -0,0 +1,21 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Status; + +public enum IdentityType { + ACI, + PNI; + + public static IdentityType fromGrpcIdentityType(final org.signal.chat.common.IdentityType grpcIdentityType) { + return switch (grpcIdentityType) { + case IDENTITY_TYPE_ACI -> ACI; + case IDENTITY_TYPE_PNI -> PNI; + case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw Status.INVALID_ARGUMENT.asRuntimeException(); + }; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java new file mode 100644 index 000000000..bc6410ea2 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Status; +import org.signal.chat.keys.GetPreKeysAnonymousRequest; +import org.signal.chat.keys.GetPreKeysResponse; +import org.signal.chat.keys.ReactorKeysAnonymousGrpc; +import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Mono; + +public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnonymousImplBase { + + private final AccountsManager accountsManager; + private final KeysManager keysManager; + + public KeysAnonymousGrpcService(final AccountsManager accountsManager, final KeysManager keysManager) { + this.accountsManager = accountsManager; + this.keysManager = keysManager; + } + + @Override + public Mono getPreKeys(final GetPreKeysAnonymousRequest request) { + return KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager) + .switchIfEmpty(Mono.error(Status.UNAUTHENTICATED.asException())) + .flatMap(targetAccount -> { + final IdentityType identityType = + IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType()); + + return UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray()) + ? KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager) + : Mono.error(Status.UNAUTHENTICATED.asException()); + }); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java new file mode 100644 index 000000000..948e1aa73 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -0,0 +1,107 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; +import io.grpc.Status; +import java.util.UUID; +import org.signal.chat.common.EcPreKey; +import org.signal.chat.common.EcSignedPreKey; +import org.signal.chat.common.KemSignedPreKey; +import org.signal.chat.common.ServiceIdentifier; +import org.signal.chat.keys.GetPreKeysResponse; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +class KeysGrpcHelper { + + @VisibleForTesting + static final long ALL_DEVICES = 0; + + static Mono findAccount(final ServiceIdentifier targetIdentifier, final AccountsManager accountsManager) { + + return Mono.just(IdentityType.fromGrpcIdentityType(targetIdentifier.getIdentityType())) + .flatMap(identityType -> { + final UUID uuid = UUIDUtil.fromByteString(targetIdentifier.getUuid()); + + return Mono.fromFuture(switch (identityType) { + case ACI -> accountsManager.getByAccountIdentifierAsync(uuid); + case PNI -> accountsManager.getByPhoneNumberIdentifierAsync(uuid); + }); + }) + .flatMap(Mono::justOrEmpty) + .onErrorMap(IllegalArgumentException.class, throwable -> Status.INVALID_ARGUMENT.asException()); + } + + static Tuple2 getIdentifierAndIdentityKey(final Account account, final IdentityType identityType) { + final UUID identifier = switch (identityType) { + case ACI -> account.getUuid(); + case PNI -> account.getPhoneNumberIdentifier(); + }; + + final IdentityKey identityKey = switch (identityType) { + case ACI -> account.getIdentityKey(); + case PNI -> account.getPhoneNumberIdentityKey(); + }; + + return Tuples.of(identifier, identityKey); + } + + static Mono getPreKeys(final Account targetAccount, final IdentityType identityType, final long targetDeviceId, final KeysManager keysManager) { + final Tuple2 identifierAndIdentityKey = getIdentifierAndIdentityKey(targetAccount, identityType); + + final Flux devices = targetDeviceId == ALL_DEVICES + ? Flux.fromIterable(targetAccount.getDevices()) + : Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId))); + + return devices + .filter(Device::isEnabled) + .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) + .flatMap(device -> Mono.zip(Mono.fromFuture(keysManager.takeEC(identifierAndIdentityKey.getT1(), device.getId())), + Mono.fromFuture(keysManager.takePQ(identifierAndIdentityKey.getT1(), device.getId()))) + .map(oneTimePreKeys -> { + final ECSignedPreKey ecSignedPreKey = switch (identityType) { + case ACI -> device.getSignedPreKey(); + case PNI -> device.getPhoneNumberIdentitySignedPreKey(); + }; + + final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()); + + oneTimePreKeys.getT1().ifPresent(ecPreKey -> preKeyBundleBuilder.setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build())); + + oneTimePreKeys.getT2().ifPresent(kemSignedPreKey -> preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build())); + + return Tuples.of(device.getId(), preKeyBundleBuilder.build()); + })) + .collectMap(Tuple2::getT1, Tuple2::getT2) + .map(preKeyBundles -> GetPreKeysResponse.newBuilder() + .setIdentityKey(ByteString.copyFrom(identifierAndIdentityKey.getT2().serialize())) + .putAllPreKeys(preKeyBundles) + .build()); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java new file mode 100644 index 000000000..f09abb2b8 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -0,0 +1,307 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.whispersystems.textsecuregcm.grpc.IdentityType.ACI; + +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import org.signal.chat.common.EcPreKey; +import org.signal.chat.common.EcSignedPreKey; +import org.signal.chat.common.KemSignedPreKey; +import org.signal.chat.keys.GetPreKeyCountRequest; +import org.signal.chat.keys.GetPreKeyCountResponse; +import org.signal.chat.keys.GetPreKeysRequest; +import org.signal.chat.keys.GetPreKeysResponse; +import org.signal.chat.keys.ReactorKeysGrpc; +import org.signal.chat.keys.SetEcSignedPreKeyRequest; +import org.signal.chat.keys.SetKemLastResortPreKeyRequest; +import org.signal.chat.keys.SetOneTimeEcPreKeysRequest; +import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest; +import org.signal.chat.keys.SetPreKeyResponse; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.InvalidKeyException; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.signal.libsignal.protocol.kem.KEMPublicKey; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { + + private final AccountsManager accountsManager; + private final KeysManager keysManager; + private final RateLimiters rateLimiters; + + private static final StatusRuntimeException INVALID_PUBLIC_KEY_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT) + .withDescription("Invalid public key") + .asRuntimeException(); + + private static final StatusRuntimeException INVALID_SIGNATURE_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT) + .withDescription("Invalid signature") + .asRuntimeException(); + + private enum PreKeyType { + EC, + KEM + } + + public KeysGrpcService(final AccountsManager accountsManager, + final KeysManager keysManager, + final RateLimiters rateLimiters) { + + this.accountsManager = accountsManager; + this.keysManager = keysManager; + this.rateLimiters = rateLimiters; + } + + @Override + protected Throwable onErrorMap(final Throwable throwable) { + return RateLimitUtil.mapRateLimitExceededException(throwable); + } + + @Override + public Mono getPreKeyCount(final GetPreKeyCountRequest request) { + return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) + .flatMap(authenticatedDevice -> Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier())) + .map(maybeAccount -> maybeAccount + .map(account -> Tuples.of(account, authenticatedDevice.deviceId())) + .orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))) + .flatMapMany(accountAndDeviceId -> Flux.just( + Tuples.of(ACI, accountAndDeviceId.getT1().getUuid(), accountAndDeviceId.getT2()), + Tuples.of(IdentityType.PNI, accountAndDeviceId.getT1().getPhoneNumberIdentifier(), accountAndDeviceId.getT2()) + )) + .flatMap(identityTypeUuidAndDeviceId -> Flux.merge( + Mono.fromFuture(keysManager.getEcCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3())) + .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.EC, ecKeyCount)), + + Mono.fromFuture(keysManager.getPqCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3())) + .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.KEM, ecKeyCount)) + )) + .reduce(GetPreKeyCountResponse.newBuilder(), (builder, tuple) -> { + final IdentityType identityType = tuple.getT1(); + final PreKeyType preKeyType = tuple.getT2(); + final int count = tuple.getT3(); + + switch (identityType) { + case ACI -> { + switch (preKeyType) { + case EC -> builder.setAciEcPreKeyCount(count); + case KEM -> builder.setAciKemPreKeyCount(count); + } + } + case PNI -> { + switch (preKeyType) { + case EC -> builder.setPniEcPreKeyCount(count); + case KEM -> builder.setPniKemPreKeyCount(count); + } + } + } + + return builder; + }) + .map(GetPreKeyCountResponse.Builder::build); + } + + @Override + public Mono getPreKeys(final GetPreKeysRequest request) { + final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice(); + + final String rateLimitKey; + { + final UUID targetUuid; + + try { + targetUuid = UUIDUtil.fromByteString(request.getTargetIdentifier().getUuid()); + } catch (final IllegalArgumentException e) { + throw Status.INVALID_ARGUMENT.asRuntimeException(); + } + + rateLimitKey = authenticatedDevice.accountIdentifier() + "." + + authenticatedDevice.deviceId() + "__" + + targetUuid + "." + + request.getDeviceId(); + } + + return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey) + .then(KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager)) + .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) + .flatMap(targetAccount -> { + final IdentityType identityType = + IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType()); + + return KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager); + }); + } + + @Override + public Mono setOneTimeEcPreKeys(final SetOneTimeEcPreKeysRequest request) { + return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) + .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), + request.getPreKeysList(), + IdentityType.fromGrpcIdentityType(request.getIdentityType()), + (requestPreKey, ignored) -> checkEcPreKey(requestPreKey), + (identifier, preKeys) -> keysManager.storeEcOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys))); + } + + @Override + public Mono setOneTimeKemSignedPreKeys(final SetOneTimeKemSignedPreKeysRequest request) { + return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) + .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(), + request.getPreKeysList(), + IdentityType.fromGrpcIdentityType(request.getIdentityType()), + KeysGrpcService::checkKemSignedPreKey, + (identifier, preKeys) -> keysManager.storeKemOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys))); + } + + private Mono storeOneTimePreKeys(final UUID authenticatedAccountUuid, + final List requestPreKeys, + final IdentityType identityType, + final BiFunction extractPreKeyFunction, + final BiFunction, CompletableFuture> storeKeysFunction) { + + return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid)) + .map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException)) + .map(account -> { + final Tuple2 identifierAndIdentityKey = + KeysGrpcHelper.getIdentifierAndIdentityKey(account, identityType); + + final List preKeys = requestPreKeys.stream() + .map(requestPreKey -> extractPreKeyFunction.apply(requestPreKey, identifierAndIdentityKey.getT2())) + .toList(); + + if (preKeys.isEmpty()) { + throw Status.INVALID_ARGUMENT.asRuntimeException(); + } + + return Tuples.of(identifierAndIdentityKey.getT1(), preKeys); + }) + .flatMap(identifierAndPreKeys -> Mono.fromFuture(storeKeysFunction.apply(identifierAndPreKeys.getT1(), identifierAndPreKeys.getT2()))) + .thenReturn(SetPreKeyResponse.newBuilder().build()); + } + + @Override + public Mono setEcSignedPreKey(final SetEcSignedPreKeyRequest request) { + return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) + .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), + request.getIdentityType(), + request.getSignedPreKey(), + KeysGrpcService::checkEcSignedPreKey, + (account, signedPreKey) -> { + final Consumer deviceUpdater = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) { + case ACI -> device -> device.setSignedPreKey(signedPreKey); + case PNI -> device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey); + }; + + final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) { + case ACI -> account.getUuid(); + case PNI -> account.getPhoneNumberIdentifier(); + }; + + return Flux.merge( + Mono.fromFuture(keysManager.storeEcSignedPreKeys(identifier, Map.of(authenticatedDevice.deviceId(), signedPreKey))), + Mono.fromFuture(accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater))) + .then(); + })); + } + + @Override + public Mono setKemLastResortPreKey(final SetKemLastResortPreKeyRequest request) { + return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice) + .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(), + request.getIdentityType(), + request.getSignedPreKey(), + KeysGrpcService::checkKemSignedPreKey, + (account, lastResortKey) -> { + final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) { + case ACI -> account.getUuid(); + case PNI -> account.getPhoneNumberIdentifier(); + }; + + return Mono.fromFuture(keysManager.storePqLastResort(identifier, Map.of(authenticatedDevice.deviceId(), lastResortKey))); + })); + } + + private Mono storeRepeatedUseKey(final UUID authenticatedAccountUuid, + final org.signal.chat.common.IdentityType identityType, + final R storeKeyRequest, + final BiFunction extractKeyFunction, + final BiFunction> storeKeyFunction) { + + return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid)) + .map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException)) + .map(account -> { + final IdentityKey identityKey = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> account.getIdentityKey(); + case PNI -> account.getPhoneNumberIdentityKey(); + }; + + final K key = extractKeyFunction.apply(storeKeyRequest, identityKey); + + return Tuples.of(account, key); + }) + .flatMap(accountAndKey -> storeKeyFunction.apply(accountAndKey.getT1(), accountAndKey.getT2())) + .thenReturn(SetPreKeyResponse.newBuilder().build()); + } + + private static ECPreKey checkEcPreKey(final EcPreKey preKey) { + try { + return new ECPreKey(preKey.getKeyId(), new ECPublicKey(preKey.getPublicKey().toByteArray())); + } catch (final InvalidKeyException e) { + throw INVALID_PUBLIC_KEY_EXCEPTION; + } + } + + private static ECSignedPreKey checkEcSignedPreKey(final EcSignedPreKey preKey, final IdentityKey identityKey) { + try { + final ECSignedPreKey ecSignedPreKey = new ECSignedPreKey(preKey.getKeyId(), + new ECPublicKey(preKey.getPublicKey().toByteArray()), + preKey.getSignature().toByteArray()); + + if (ecSignedPreKey.signatureValid(identityKey)) { + return ecSignedPreKey; + } else { + throw INVALID_SIGNATURE_EXCEPTION; + } + } catch (final InvalidKeyException e) { + throw INVALID_PUBLIC_KEY_EXCEPTION; + } + } + + private static KEMSignedPreKey checkKemSignedPreKey(final KemSignedPreKey preKey, final IdentityKey identityKey) { + try { + final KEMSignedPreKey kemSignedPreKey = new KEMSignedPreKey(preKey.getKeyId(), + new KEMPublicKey(preKey.getPublicKey().toByteArray()), + preKey.getSignature().toByteArray()); + + if (kemSignedPreKey.signatureValid(identityKey)) { + return kemSignedPreKey; + } else { + throw INVALID_SIGNATURE_EXCEPTION; + } + } catch (final InvalidKeyException e) { + throw INVALID_PUBLIC_KEY_EXCEPTION; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java new file mode 100644 index 000000000..dadfd5417 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import java.time.Duration; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; + +public class RateLimitUtil { + + public static final Metadata.Key RETRY_AFTER_DURATION_KEY = + Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() { + @Override + public String toAsciiString(final Duration value) { + return value.toString(); + } + + @Override + public Duration parseAsciiString(final String serialized) { + return Duration.parse(serialized); + } + }); + + public static Throwable mapRateLimitExceededException(final Throwable throwable) { + if (throwable instanceof RateLimitExceededException rateLimitExceededException) { + @Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration() + .map(duration -> { + final Metadata metadata = new Metadata(); + metadata.put(RETRY_AFTER_DURATION_KEY, duration); + + return metadata; + }).orElse(null); + + return new StatusException(Status.RESOURCE_EXHAUSTED, trailers); + } + + return throwable; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 70e7d1cd3..b53df9a1c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits; import java.util.UUID; import java.util.concurrent.CompletionStage; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import reactor.core.publisher.Mono; public interface RateLimiter { @@ -53,6 +54,10 @@ public interface RateLimiter { return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); } + default Mono validateReactive(final String key) { + return Mono.fromFuture(validateAsync(key).toCompletableFuture()); + } + default boolean hasAvailablePermits(final UUID accountUuid, final int permits) { return hasAvailablePermits(accountUuid.toString(), permits); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 90e0ea8c0..65167ceb3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -90,6 +90,14 @@ public class KeysManager { return pqLastResortKeys.store(identifier, keys); } + public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) { + return ecPreKeys.store(identifier, deviceId, preKeys); + } + + public CompletableFuture storeKemOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) { + return pqPreKeys.store(identifier, deviceId, preKeys); + } + public CompletableFuture> takeEC(final UUID identifier, final long deviceId) { return ecPreKeys.take(identifier, deviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java new file mode 100644 index 000000000..1fa8c6824 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +/** + * An abstract base class for runtime exceptions that do not include a stack trace. Stackless exceptions are generally + * intended for internal error-handling cases where the error will never be logged or otherwise reported. + */ +public abstract class NoStackTraceRuntimeException extends RuntimeException { + + public NoStackTraceRuntimeException() { + super(null, null, true, false); + } + + public NoStackTraceRuntimeException(final String message) { + super(message, null, true, false); + } + + public NoStackTraceRuntimeException(final String message, final Throwable cause) { + super(message, cause, true, false); + } + + public NoStackTraceRuntimeException(final Throwable cause) { + super(null, cause, true, false); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java index 3f1d91d7c..38f2cb34b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.util; +import com.google.protobuf.ByteString; import java.nio.BufferUnderflowException; import java.nio.ByteBuffer; import java.util.Optional; @@ -27,6 +28,14 @@ public final class UUIDUtil { return byteBuffer.flip(); } + public static ByteString toByteString(final UUID uuid) { + return ByteString.copyFrom(toByteBuffer(uuid)); + } + + public static UUID fromByteString(final ByteString byteString) { + return fromBytes(byteString.toByteArray()); + } + public static UUID fromBytes(final byte[] bytes) { return fromByteBuffer(ByteBuffer.wrap(bytes)); } diff --git a/service/src/main/proto/org/signal/chat/common.proto b/service/src/main/proto/org/signal/chat/common.proto new file mode 100644 index 000000000..5add25f6c --- /dev/null +++ b/service/src/main/proto/org/signal/chat/common.proto @@ -0,0 +1,71 @@ +syntax = "proto3"; + +option java_multiple_files = true; + +package org.signal.chat.common; + +enum IdentityType { + IDENTITY_TYPE_UNSPECIFIED = 0; + IDENTITY_TYPE_ACI = 1; + IDENTITY_TYPE_PNI = 2; +} + +message ServiceIdentifier { + /** + * The type of identity represented by this service identifier. + */ + IdentityType identity_type = 1; + + /** + * The UUID of the identity represented by this service identifier. + */ + bytes uuid = 2; +} + +message EcPreKey { + /** + * A locally-unique identifier for this key. + */ + uint64 key_id = 1; + + /** + * The serialized form of the public key. + */ + bytes public_key = 2; +} + +message EcSignedPreKey { + /** + * A locally-unique identifier for this key. + */ + uint64 key_id = 1; + + /** + * The serialized form of the public key. + */ + bytes public_key = 2; + + /** + * A signature of the public key, verifiable with the identity key for the + * account/identity associated with this pre-key. + */ + bytes signature = 3; +} + +message KemSignedPreKey { + /** + * A locally-unique identifier for this key. + */ + uint64 key_id = 1; + + /** + * The serialized form of the public key. + */ + bytes public_key = 2; + + /** + * A signature of the public key, verifiable with the identity key for the + * account/identity associated with this pre-key. + */ + bytes signature = 3; +} diff --git a/service/src/main/proto/org/signal/chat/keys.proto b/service/src/main/proto/org/signal/chat/keys.proto new file mode 100644 index 000000000..7f7af8fb6 --- /dev/null +++ b/service/src/main/proto/org/signal/chat/keys.proto @@ -0,0 +1,263 @@ +syntax = "proto3"; + +option java_multiple_files = true; + +package org.signal.chat.keys; + +import "org/signal/chat/common.proto"; + +/** + * Provides methods for working with pre-keys. + */ +service Keys { + + /** + * Retrieves an approximate count of the number of the various kinds of + * pre-keys stored for the authenticated device. + */ + rpc GetPreKeyCount (GetPreKeyCountRequest) returns (GetPreKeyCountResponse) {} + + /** + * Retrieves a set of pre-keys for establishing a session with the targeted + * device or devices. Note that callers with an unidentified access key for + * the targeted account should use the version of this method in + * `KeysAnonymous` instead. + * + * This RPC may fail with a `NOT_FOUND` status if the target account was not + * found, if no active device with the given ID (if specified) was found on + * the target account, or if the account has no active devices. It may also + * fail with a `RESOURCE_EXHAUSTED` if a rate limit for fetching keys has been + * exceeded, in which case a `retry-after` header containing an ISO 8601 + * duration string will be present in the response trailers. + */ + rpc GetPreKeys(GetPreKeysRequest) returns (GetPreKeysResponse) {} + + /** + * Uploads a new set of one-time EC pre-keys for the authenticated device, + * clearing any previously-stored pre-keys. Note that all keys submitted via + * a single call to this method _must_ have the same identity type (i.e. if + * the first key has an ACI identity type, then all other keys in the same + * stream must also have an ACI identity type). + * + * This method returns a single status code; if all keys were validated/stored + * successfully, then this method will return `SET_PRE_KEY_STATUS_OK`. If one + * or more keys could not be stored, this method will return a status code + * indicating the reason. If multiple keys had problems, which status code + * will be returned is not defined. + * + * This RPC may fail with an `INVALID_ARGUMENT` status if one or more of the + * given pre-keys was structurally invalid or if the list of pre-keys was + * empty. + */ + rpc SetOneTimeEcPreKeys (SetOneTimeEcPreKeysRequest) returns (SetPreKeyResponse) {} + + /** + * Uploads a new set of one-time KEM pre-keys for the authenticated device, + * clearing any previously-stored pre-keys. Note that all keys submitted via + * a single call to this method _must_ have the same identity type (i.e. if + * the first key has an ACI identity type, then all other keys in the same + * stream must also have an ACI identity type). + * + * This method returns a single status code; if all keys were validated/stored + * successfully, then this method will return `SET_PRE_KEY_STATUS_OK`. If one + * or more keys could not be stored, this method will return a status code + * indicating the reason. If multiple keys had problems, which status code + * will be returned is not defined. + * + * This RPC may fail with an `INVALID_ARGUMENT` status if one or more of the + * given pre-keys was structurally invalid, had an invalid signature, or if + * the list of pre-keys was empty. + */ + rpc SetOneTimeKemSignedPreKeys (SetOneTimeKemSignedPreKeysRequest) returns (SetPreKeyResponse) {} + + /** + * Sets the signed EC pre-key for one identity (i.e. ACI or PNI) associated + * with the authenticated device. + * + * This RPC may fail with an `INVALID_ARGUMENT` status if the given pre-key + * was structurally invalid, had a bad signature, or was missing entirely. + */ + rpc SetEcSignedPreKey (SetEcSignedPreKeyRequest) returns (SetPreKeyResponse) {} + + /** + * Sets the last-resort KEM pre-key for one identity (i.e. ACI or PNI) + * associated with the authenticated device. + * + * This RPC may fail with an `INVALID_ARGUMENT` status if the given pre-key + * was structurally invalid, had a bad signature, or was missing entirely. + */ + rpc SetKemLastResortPreKey (SetKemLastResortPreKeyRequest) returns (SetPreKeyResponse) {} +} + +/** + * Provides methods for working with pre-keys using "unidentified access" + * credentials. + */ +service KeysAnonymous { + + /** + * Retrieves a set of pre-keys for establishing a session with the targeted + * device or devices. Callers must not submit any self-identifying credentials + * when calling this method and must instead present the targeted account's + * unidentified access key as an anonymous authentication mechanism. Callers + * without an unidentified access key should use the equivalent, authenticated + * method in `Keys` instead. + * + * This RPC may fail with an `UNAUTHENTICATED` status if the given + * unidentified access key did not match the target account's unidentified + * access key or if the target account was not found. It may also fail with a + * `NOT_FOUND` status if no active device with the given ID (if specified) was + * found on the target account, or if the target account has no active + * devices. + */ + rpc GetPreKeys(GetPreKeysAnonymousRequest) returns (GetPreKeysResponse) {} +} + +message GetPreKeyCountRequest { +} + +message GetPreKeyCountResponse { + /** + * The approximate number of one-time EC pre-keys stored for the + * authenticated device and associated with the caller's ACI. + */ + uint32 aci_ec_pre_key_count = 1; + + /** + * The approximate number of one-time Kyber pre-keys stored for the + * authenticated device and associated with the caller's ACI. + */ + uint32 aci_kem_pre_key_count = 2; + + /** + * The approximate number of one-time EC pre-keys stored for the + * authenticated device and associated with the caller's PNI. + */ + uint32 pni_ec_pre_key_count = 3; + + /** + * The approximate number of one-time KEM pre-keys stored for the + * authenticated device and associated with the caller's PNI. + */ + uint32 pni_kem_pre_key_count = 4; +} + +message GetPreKeysRequest { + /** + * The service identifier of the account for which to retrieve pre-keys. + */ + common.ServiceIdentifier target_identifier = 1; + + /** + * The ID of the device associated with the targeted account for which to + * retrieve pre-keys. If not set, pre-keys are returned for all devices + * associated with the targeted account. + */ + uint64 device_id = 2; +} + +message GetPreKeysAnonymousRequest { + /** + * The service identifier of the account for which to retrieve pre-keys. + */ + common.ServiceIdentifier target_identifier = 1; + + /** + * The ID of the device associated with the targeted account for which to + * retrieve pre-keys. If not set, pre-keys are returned for all devices + * associated with the targeted account. + */ + uint64 device_id = 2; + + /** + * The unidentified access key (UAK) for the targeted account. + */ + bytes unidentified_access_key = 3; +} + +message GetPreKeysResponse { + message PreKeyBundle { + /** + * The EC signed pre-key associated with the targeted + * account/device/identity. + */ + common.EcSignedPreKey ec_signed_pre_key = 1; + + /** + * A one-time EC pre-key for the targeted account/device/identity. May not + * be set if no one-time EC pre-keys are available. + */ + common.EcPreKey ec_one_time_pre_key = 2; + + /** + * A one-time KEM pre-key (or a last-resort KEM pre-key) for the targeted + * account/device/identity. May not be set if the targeted device has not + * yet uploaded any KEM pre-keys. + */ + common.KemSignedPreKey kem_one_time_pre_key = 3; + } + + /** + * The identity key associated with the targeted account/identity. + */ + bytes identity_key = 1; + + /** + * A map of device IDs to pre-key "bundles" for the targeted account. + */ + map pre_keys = 2; +} + +message SetOneTimeEcPreKeysRequest { + /** + * The identity type (i.e. ACI/PNI) with which the keys in this request are + * associated. + */ + common.IdentityType identity_type = 1; + + /** + * The unsigned EC pre-keys to be stored. + */ + repeated common.EcPreKey pre_keys = 2; +} + +message SetOneTimeKemSignedPreKeysRequest { + /** + * The identity type (i.e. ACI/PNI) with which the keys in this request are + * associated. + */ + common.IdentityType identity_type = 1; + + /** + * The KEM pre-keys to be stored. + */ + repeated common.KemSignedPreKey pre_keys = 2; +} + +message SetEcSignedPreKeyRequest { + /** + * The identity type (i.e. ACI/PNI) with which this key is associated. + */ + common.IdentityType identity_type = 1; + + /** + * The signed EC pre-key itself. + */ + common.EcSignedPreKey signed_pre_key = 2; +} + +message SetKemLastResortPreKeyRequest { + /** + * The identity type (i.e. ACI/PNI) with which this key is associated. + */ + common.IdentityType identity_type = 1; + + /** + * The signed KEM pre-key itself. + */ + common.KemSignedPreKey signed_pre_key = 2; +} + +message SetPreKeyResponse { +} + diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java new file mode 100644 index 000000000..993974b07 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.security.SecureRandom; +import java.util.Optional; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.storage.Account; + +class UnidentifiedAccessUtilTest { + + @ParameterizedTest + @MethodSource + void checkUnidentifiedAccess(@Nullable final byte[] targetUak, + final boolean unrestrictedUnidentifiedAccess, + final byte[] presentedUak, + final boolean expectAccessAllowed) { + + final Account account = mock(Account.class); + when(account.getUnidentifiedAccessKey()).thenReturn(Optional.ofNullable(targetUak)); + when(account.isUnrestrictedUnidentifiedAccess()).thenReturn(unrestrictedUnidentifiedAccess); + + assertEquals(expectAccessAllowed, UnidentifiedAccessUtil.checkUnidentifiedAccess(account, presentedUak)); + } + + private static Stream checkUnidentifiedAccess() { + final byte[] uak = new byte[16]; + new SecureRandom().nextBytes(uak); + + final byte[] incorrectUak = new byte[uak.length + 1]; + + return Stream.of( + Arguments.of(null, false, uak, false), + Arguments.of(null, true, uak, true), + Arguments.of(uak, false, incorrectUak, false), + Arguments.of(uak, false, uak, true), + Arguments.of(uak, true, incorrectUak, true), + Arguments.of(uak, true, uak, true) + ); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java new file mode 100644 index 000000000..a4fd52df8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth.grpc; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import java.util.UUID; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.util.Pair; + +public class MockAuthenticationInterceptor implements ServerInterceptor { + + @Nullable + private Pair authenticatedDevice; + + public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) { + authenticatedDevice = new Pair<>(accountIdentifier, deviceId); + } + + public void clearAuthenticatedDevice() { + authenticatedDevice = null; + } + + @Override + public ServerCall.Listener interceptCall(final ServerCall call, + final Metadata headers, + final ServerCallHandler next) { + + if (authenticatedDevice != null) { + final Context context = Context.current() + .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first()) + .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second()); + + return Contexts.interceptCall(context, call, headers, next); + } + + return next.startCall(call, headers); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java new file mode 100644 index 000000000..725ecaa9d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java @@ -0,0 +1,119 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.BindableService; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.util.MutableHandlerRegistry; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +// This is mostly a direct port of +// https://github.com/grpc/grpc-java/blob/master/testing/src/main/java/io/grpc/testing/GrpcServerRule.java, but for +// JUnit 5. +public class GrpcServerExtension implements BeforeEachCallback, AfterEachCallback { + + private ManagedChannel channel; + private Server server; + private String serverName; + private MutableHandlerRegistry serviceRegistry; + private boolean useDirectExecutor; + + /** + * Returns {@code this} configured to use a direct executor for the {@link ManagedChannel} and + * {@link Server}. This can only be called at the rule instantiation. + */ + public final GrpcServerExtension directExecutor() { + if (serverName != null) { + throw new IllegalStateException("directExecutor() can only be called at the rule instantiation"); + } + + useDirectExecutor = true; + return this; + } + + /** + * Returns a {@link ManagedChannel} connected to this service. + */ + public final ManagedChannel getChannel() { + return channel; + } + + /** + * Returns the underlying gRPC {@link Server} for this service. + */ + public final Server getServer() { + return server; + } + + /** + * Returns the randomly generated server name for this service. + */ + public final String getServerName() { + return serverName; + } + + /** + * Returns the service registry for this service. The registry is used to add service instances + * (e.g. {@link BindableService} or {@link ServerServiceDefinition} to the server. + */ + public final MutableHandlerRegistry getServiceRegistry() { + return serviceRegistry; + } + + @Override + public void beforeEach(final ExtensionContext extensionContext) throws Exception { + serverName = UUID.randomUUID().toString(); + serviceRegistry = new MutableHandlerRegistry(); + + final InProcessServerBuilder serverBuilder = InProcessServerBuilder.forName(serverName) + .fallbackHandlerRegistry(serviceRegistry); + + if (useDirectExecutor) { + serverBuilder.directExecutor(); + } + + server = serverBuilder.build().start(); + + final InProcessChannelBuilder channelBuilder = InProcessChannelBuilder.forName(serverName); + + if (useDirectExecutor) { + channelBuilder.directExecutor(); + } + + channel = channelBuilder.build(); + } + + @Override + public void afterEach(final ExtensionContext extensionContext) throws Exception { + serverName = null; + serviceRegistry = null; + + channel.shutdown(); + server.shutdown(); + + try { + channel.awaitTermination(1, TimeUnit.MINUTES); + server.awaitTermination(1, TimeUnit.MINUTES); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } finally { + channel.shutdownNow(); + channel = null; + + server.shutdownNow(); + server = null; + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java new file mode 100644 index 000000000..81cf7e68e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java @@ -0,0 +1,211 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.security.SecureRandom; +import java.util.Collections; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.signal.chat.common.EcPreKey; +import org.signal.chat.common.EcSignedPreKey; +import org.signal.chat.common.IdentityType; +import org.signal.chat.common.KemSignedPreKey; +import org.signal.chat.common.ServiceIdentifier; +import org.signal.chat.keys.GetPreKeysAnonymousRequest; +import org.signal.chat.keys.GetPreKeysResponse; +import org.signal.chat.keys.KeysAnonymousGrpc; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.UUIDUtil; + +class KeysAnonymousGrpcServiceTest { + + private AccountsManager accountsManager; + private KeysManager keysManager; + + private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub; + + @RegisterExtension + static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); + + @BeforeEach + void setUp() { + accountsManager = mock(AccountsManager.class); + keysManager = mock(KeysManager.class); + + final KeysAnonymousGrpcService keysGrpcService = + new KeysAnonymousGrpcService(accountsManager, keysManager); + + keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); + + GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService); + } + + @Test + void getPreKeys() { + final Account targetAccount = mock(Account.class); + final Device targetDevice = mock(Device.class); + + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); + final UUID identifier = UUID.randomUUID(); + + final byte[] unidentifiedAccessKey = new byte[16]; + new SecureRandom().nextBytes(unidentifiedAccessKey); + + when(targetDevice.getId()).thenReturn(Device.MASTER_ID); + when(targetDevice.isEnabled()).thenReturn(true); + when(targetAccount.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(targetDevice)); + + when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey)); + when(targetAccount.getUuid()).thenReturn(identifier); + when(targetAccount.getIdentityKey()).thenReturn(identityKey); + when(accountsManager.getByAccountIdentifierAsync(identifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + + final ECPreKey ecPreKey = new ECPreKey(1, Curve.generateKeyPair().getPublicKey()); + final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(2, identityKeyPair); + final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(3, identityKeyPair); + + when(keysManager.takeEC(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey))); + when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey))); + when(targetDevice.getSignedPreKey()).thenReturn(ecSignedPreKey); + + final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(identifier)) + .build()) + .setDeviceId(Device.MASTER_ID) + .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) + .build()); + + final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder() + .setIdentityKey(ByteString.copyFrom(identityKey.serialize())) + .putPreKeys(Device.MASTER_ID, GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey())) + .build()) + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKey.signature())) + .build()) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemSignedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemSignedPreKey.signature())) + .build()) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + @Test + void getPreKeysIncorrectUnidentifiedAccessKey() { + final Account targetAccount = mock(Account.class); + + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); + final UUID identifier = UUID.randomUUID(); + + final byte[] unidentifiedAccessKey = new byte[16]; + new SecureRandom().nextBytes(unidentifiedAccessKey); + + when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey)); + when(targetAccount.getUuid()).thenReturn(identifier); + when(targetAccount.getIdentityKey()).thenReturn(identityKey); + when(accountsManager.getByAccountIdentifierAsync(identifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException statusRuntimeException = + assertThrows(StatusRuntimeException.class, + () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(identifier)) + .build()) + .setDeviceId(Device.MASTER_ID) + .build())); + + assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode()); + } + + @Test + void getPreKeysAccountNotFound() { + when(accountsManager.getByAccountIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID())) + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build())); + + assertEquals(Status.Code.UNAUTHENTICATED, exception.getStatus().getCode()); + } + + @ParameterizedTest + @ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1}) + void getPreKeysDeviceNotFound(final long deviceId) { + final UUID accountIdentifier = UUID.randomUUID(); + + final byte[] unidentifiedAccessKey = new byte[16]; + new SecureRandom().nextBytes(unidentifiedAccessKey); + + final Account targetAccount = mock(Account.class); + when(targetAccount.getUuid()).thenReturn(accountIdentifier); + when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey())); + when(targetAccount.getDevices()).thenReturn(Collections.emptyList()); + when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty()); + when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey)); + + when(accountsManager.getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder() + .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey)) + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(accountIdentifier)) + .build()) + .setDeviceId(deviceId) + .build())); + + assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java new file mode 100644 index 000000000..d2431a281 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java @@ -0,0 +1,678 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.signal.chat.common.EcPreKey; +import org.signal.chat.common.EcSignedPreKey; +import org.signal.chat.common.KemSignedPreKey; +import org.signal.chat.common.ServiceIdentifier; +import org.signal.chat.keys.GetPreKeyCountRequest; +import org.signal.chat.keys.GetPreKeyCountResponse; +import org.signal.chat.keys.GetPreKeysRequest; +import org.signal.chat.keys.GetPreKeysResponse; +import org.signal.chat.keys.KeysGrpc; +import org.signal.chat.keys.SetEcSignedPreKeyRequest; +import org.signal.chat.keys.SetKemLastResortPreKeyRequest; +import org.signal.chat.keys.SetOneTimeEcPreKeysRequest; +import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.UUIDUtil; +import reactor.core.publisher.Mono; + +class KeysGrpcServiceTest { + + private AccountsManager accountsManager; + private KeysManager keysManager; + private RateLimiter preKeysRateLimiter; + + private Device authenticatedDevice; + + private KeysGrpc.KeysBlockingStub keysStub; + + private static final UUID AUTHENTICATED_ACI = UUID.randomUUID(); + private static final UUID AUTHENTICATED_PNI = UUID.randomUUID(); + private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID; + + private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + + @RegisterExtension + static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension(); + + @BeforeEach + void setUp() { + accountsManager = mock(AccountsManager.class); + keysManager = mock(KeysManager.class); + preKeysRateLimiter = mock(RateLimiter.class); + + final RateLimiters rateLimiters = mock(RateLimiters.class); + when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter); + + when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty()); + + final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters); + keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel()); + + authenticatedDevice = mock(Device.class); + when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID); + + final Account authenticatedAccount = mock(Account.class); + when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI); + when(authenticatedAccount.getPhoneNumberIdentifier()).thenReturn(AUTHENTICATED_PNI); + when(authenticatedAccount.getIdentityKey()).thenReturn(new IdentityKey(ACI_IDENTITY_KEY_PAIR.getPublicKey())); + when(authenticatedAccount.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey())); + when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice)); + + final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor(); + mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID); + + GRPC_SERVER_EXTENSION.getServiceRegistry() + .addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor)); + + when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount)); + when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount)); + + when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); + when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount))); + } + + @Test + void getPreKeyCount() { + when(keysManager.getEcCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(1)); + + when(keysManager.getPqCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(2)); + + when(keysManager.getEcCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(3)); + + when(keysManager.getPqCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID)) + .thenReturn(CompletableFuture.completedFuture(4)); + + assertEquals(GetPreKeyCountResponse.newBuilder() + .setAciEcPreKeyCount(1) + .setAciKemPreKeyCount(2) + .setPniEcPreKeyCount(3) + .setPniKemPreKeyCount(4) + .build(), + keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build())); + } + + @ParameterizedTest + @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) + void setOneTimeEcPreKeys(final org.signal.chat.common.IdentityType identityType) { + final List preKeys = new ArrayList<>(); + + for (int keyId = 0; keyId < 100; keyId++) { + preKeys.add(new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey())); + } + + when(keysManager.storeEcOneTimePreKeys(any(), anyLong(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + //noinspection ResultOfMethodCallIgnored + keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder() + .setIdentityType(identityType) + .addAllPreKeys(preKeys.stream() + .map(preKey -> EcPreKey.newBuilder() + .setKeyId(preKey.keyId()) + .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey())) + .build()) + .toList()) + .build()); + + final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> AUTHENTICATED_ACI; + case PNI -> AUTHENTICATED_PNI; + }; + + verify(keysManager).storeEcOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys); + } + + @ParameterizedTest + @MethodSource + void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) { + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request)); + + assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + } + + private static Stream setOneTimeEcPreKeysWithError() { + return Stream.of( + // Missing identity type + Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder() + .addPreKeys(EcPreKey.newBuilder() + .setKeyId(1) + .setPublicKey(ByteString.copyFrom(Curve.generateKeyPair().getPublicKey().serialize())) + .build()) + .build()), + + // Invalid public key + Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .addPreKeys(EcPreKey.newBuilder() + .setKeyId(1) + .setPublicKey(ByteString.empty()) + .build()) + .build()), + + // No keys + Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .build()) + ); + } + + @ParameterizedTest + @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) + void setOneTimeKemSignedPreKeys(final org.signal.chat.common.IdentityType identityType) { + final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> ACI_IDENTITY_KEY_PAIR; + case PNI -> PNI_IDENTITY_KEY_PAIR; + }; + + final List preKeys = new ArrayList<>(); + + for (int keyId = 0; keyId < 100; keyId++) { + preKeys.add(KeysHelper.signedKEMPreKey(keyId, identityKeyPair)); + } + + when(keysManager.storeKemOneTimePreKeys(any(), anyLong(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + //noinspection ResultOfMethodCallIgnored + keysStub.setOneTimeKemSignedPreKeys( + SetOneTimeKemSignedPreKeysRequest.newBuilder() + .setIdentityType(identityType) + .addAllPreKeys(preKeys.stream() + .map(preKey -> KemSignedPreKey.newBuilder() + .setKeyId(preKey.keyId()) + .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(preKey.signature())) + .build()) + .toList()) + .build()); + + final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> AUTHENTICATED_ACI; + case PNI -> AUTHENTICATED_PNI; + }; + + verify(keysManager).storeKemOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys); + } + + @ParameterizedTest + @MethodSource + void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) { + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request)); + + assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + } + + private static Stream setOneTimeKemSignedPreKeysWithError() { + final KEMSignedPreKey signedPreKey = KeysHelper.signedKEMPreKey(1, ACI_IDENTITY_KEY_PAIR); + + return Stream.of( + // Missing identity type + Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder() + .addPreKeys(KemSignedPreKey.newBuilder() + .setKeyId(1) + .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(signedPreKey.signature())) + .build()) + .build()), + + // Invalid public key + Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .addPreKeys(KemSignedPreKey.newBuilder() + .setKeyId(1) + .setPublicKey(ByteString.empty()) + .setSignature(ByteString.copyFrom(signedPreKey.signature())) + .build()) + .build()), + + // Invalid signature + Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .addPreKeys(KemSignedPreKey.newBuilder() + .setKeyId(1) + .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey())) + .setSignature(ByteString.empty()) + .build()) + .build()), + + // No keys + Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .build()) + ); + } + + @ParameterizedTest + @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) + void setSignedPreKey(final org.signal.chat.common.IdentityType identityType) { + when(accountsManager.updateDeviceAsync(any(), anyLong(), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + final long deviceId = invocation.getArgument(1); + final Consumer deviceUpdater = invocation.getArgument(2); + + account.getDevice(deviceId).ifPresent(deviceUpdater); + + return CompletableFuture.completedFuture(account); + }); + + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> ACI_IDENTITY_KEY_PAIR; + case PNI -> PNI_IDENTITY_KEY_PAIR; + }; + + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair); + + //noinspection ResultOfMethodCallIgnored + keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder() + .setIdentityType(identityType) + .setSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(signedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(signedPreKey.signature())) + .build()) + .build()); + + switch (identityType) { + case IDENTITY_TYPE_ACI -> { + verify(authenticatedDevice).setSignedPreKey(signedPreKey); + verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey)); + } + + case IDENTITY_TYPE_PNI -> { + verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey); + verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey)); + } + } + } + + @ParameterizedTest + @MethodSource + void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) { + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request)); + + assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + } + + private static Stream setSignedPreKeyWithError() { + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, ACI_IDENTITY_KEY_PAIR); + + return Stream.of( + // Missing identity type + Arguments.of(SetEcSignedPreKeyRequest.newBuilder() + .setSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(signedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(signedPreKey.signature())) + .build()) + .build()), + + // Invalid public key + Arguments.of(SetEcSignedPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(signedPreKey.keyId()) + .setPublicKey(ByteString.empty()) + .setSignature(ByteString.copyFrom(signedPreKey.signature())) + .build()) + .build()), + + // Invalid signature + Arguments.of(SetEcSignedPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(signedPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey())) + .setSignature(ByteString.empty()) + .build()) + .build()), + + // Missing key + Arguments.of(SetEcSignedPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .build()) + ); + } + + @ParameterizedTest + @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) + void setLastResortPreKey(final org.signal.chat.common.IdentityType identityType) { + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) { + case ACI -> ACI_IDENTITY_KEY_PAIR; + case PNI -> PNI_IDENTITY_KEY_PAIR; + }; + + final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair); + + //noinspection ResultOfMethodCallIgnored + keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder() + .setIdentityType(identityType) + .setSignedPreKey(KemSignedPreKey.newBuilder() + .setKeyId(lastResortPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(lastResortPreKey.signature())) + .build()) + .build()); + + final UUID expectedIdentifier = switch (identityType) { + case IDENTITY_TYPE_ACI -> AUTHENTICATED_ACI; + case IDENTITY_TYPE_PNI -> AUTHENTICATED_PNI; + case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw new AssertionError("Bad identity type"); + }; + + verify(keysManager).storePqLastResort(expectedIdentifier, Map.of(AUTHENTICATED_DEVICE_ID, lastResortPreKey)); + } + + @ParameterizedTest + @MethodSource + void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) { + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request)); + + assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode()); + } + + private static Stream setLastResortPreKeyWithError() { + final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, ACI_IDENTITY_KEY_PAIR); + + return Stream.of( + // No identity type + Arguments.of(SetKemLastResortPreKeyRequest.newBuilder() + .setSignedPreKey(KemSignedPreKey.newBuilder() + .setKeyId(lastResortPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey())) + .setSignature(ByteString.copyFrom(lastResortPreKey.signature())) + .build()) + .build()), + + // Bad public key + Arguments.of(SetKemLastResortPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setSignedPreKey(KemSignedPreKey.newBuilder() + .setKeyId(lastResortPreKey.keyId()) + .setPublicKey(ByteString.empty()) + .setSignature(ByteString.copyFrom(lastResortPreKey.signature())) + .build()) + .build()), + + // Bad signature + Arguments.of(SetKemLastResortPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setSignedPreKey(KemSignedPreKey.newBuilder() + .setKeyId(lastResortPreKey.keyId()) + .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey())) + .setSignature(ByteString.empty()) + .build()) + .build()), + + // Missing key + Arguments.of(SetKemLastResortPreKeyRequest.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .build()) + ); + } + + @ParameterizedTest + @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) + void getPreKeys(final org.signal.chat.common.IdentityType identityType) { + final Account targetAccount = mock(Account.class); + + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); + final UUID identifier = UUID.randomUUID(); + + if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) { + when(targetAccount.getUuid()).thenReturn(identifier); + when(targetAccount.getIdentityKey()).thenReturn(identityKey); + when(accountsManager.getByAccountIdentifierAsync(identifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + } else { + when(targetAccount.getUuid()).thenReturn(UUID.randomUUID()); + when(targetAccount.getPhoneNumberIdentifier()).thenReturn(identifier); + when(targetAccount.getPhoneNumberIdentityKey()).thenReturn(identityKey); + when(accountsManager.getByPhoneNumberIdentifierAsync(identifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + } + + final Map ecOneTimePreKeys = new HashMap<>(); + final Map kemPreKeys = new HashMap<>(); + final Map ecSignedPreKeys = new HashMap<>(); + + final Map devices = new HashMap<>(); + + for (final long deviceId : List.of(1, 2)) { + ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey())); + kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair)); + ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair)); + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + when(device.isEnabled()).thenReturn(true); + + if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) { + when(device.getSignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId)); + } else { + when(device.getPhoneNumberIdentitySignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId)); + } + + devices.put(deviceId, device); + when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device)); + } + + when(targetAccount.getDevices()).thenReturn(new ArrayList<>(devices.values())); + + ecOneTimePreKeys.forEach((deviceId, preKey) -> when(keysManager.takeEC(identifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + + kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey)))); + + { + final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(identityType) + .setUuid(UUIDUtil.toByteString(identifier)) + .build()) + .setDeviceId(1) + .build()); + + final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder() + .setIdentityKey(ByteString.copyFrom(identityKey.serialize())) + .putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature())) + .build()) + .setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecOneTimePreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey())) + .build()) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemPreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature())) + .build()) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + + when(keysManager.takeEC(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + { + final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(identityType) + .setUuid(UUIDUtil.toByteString(identifier)) + .build()) + .build()); + + final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder() + .setIdentityKey(ByteString.copyFrom(identityKey.serialize())) + .putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature())) + .build()) + .setEcOneTimePreKey(EcPreKey.newBuilder() + .setKeyId(ecOneTimePreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey())) + .build()) + .setKemOneTimePreKey(KemSignedPreKey.newBuilder() + .setKeyId(kemPreKeys.get(1L).keyId()) + .setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey())) + .setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature())) + .build()) + .build()) + .putPreKeys(2, GetPreKeysResponse.PreKeyBundle.newBuilder() + .setEcSignedPreKey(EcSignedPreKey.newBuilder() + .setKeyId(ecSignedPreKeys.get(2L).keyId()) + .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(2L).serializedPublicKey())) + .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(2L).signature())) + .build()) + .build()) + .build(); + + assertEquals(expectedResponse, response); + } + } + + @Test + void getPreKeysAccountNotFound() { + when(accountsManager.getByAccountIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build())); + + assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + } + + @ParameterizedTest + @ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1}) + void getPreKeysDeviceNotFound(final long deviceId) { + final UUID accountIdentifier = UUID.randomUUID(); + + final Account targetAccount = mock(Account.class); + when(targetAccount.getUuid()).thenReturn(accountIdentifier); + when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey())); + when(targetAccount.getDevices()).thenReturn(Collections.emptyList()); + when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty()); + + when(accountsManager.getByAccountIdentifierAsync(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(accountIdentifier)) + .build()) + .setDeviceId(deviceId) + .build())); + + assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode()); + } + + @Test + void getPreKeysRateLimited() { + final Account targetAccount = mock(Account.class); + when(targetAccount.getUuid()).thenReturn(UUID.randomUUID()); + when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey())); + when(targetAccount.getDevices()).thenReturn(Collections.emptyList()); + when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty()); + + when(accountsManager.getByAccountIdentifierAsync(any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); + + final Duration retryAfterDuration = Duration.ofMinutes(7); + + when(preKeysRateLimiter.validateReactive(anyString())) + .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false))); + + @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception = + assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder() + .setTargetIdentifier(ServiceIdentifier.newBuilder() + .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) + .setUuid(UUIDUtil.toByteString(UUID.randomUUID())) + .build()) + .build())); + + assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode()); + assertNotNull(exception.getTrailers()); + assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY)); + } +}