From 5627209fdd44b4d235c31bb9b1ada8ecfcfb2290 Mon Sep 17 00:00:00 2001
From: Jon Chambers <63609320+jon-signal@users.noreply.github.com>
Date: Thu, 20 Jul 2023 11:10:26 -0400
Subject: [PATCH] Add a gRPC service for working with pre-keys
---
pom.xml | 16 +
service/pom.xml | 5 +
.../textsecuregcm/WhisperServerService.java | 14 +-
.../auth/UnidentifiedAccessUtil.java | 32 +
.../auth/grpc/AuthenticatedDevice.java | 11 +
.../auth/grpc/AuthenticationUtil.java | 28 +-
...icCredentialAuthenticationInterceptor.java | 5 +-
.../auth/grpc/NotAuthenticatedException.java | 13 -
.../textsecuregcm/grpc/IdentityType.java | 21 +
.../grpc/KeysAnonymousGrpcService.java | 40 ++
.../textsecuregcm/grpc/KeysGrpcHelper.java | 107 +++
.../textsecuregcm/grpc/KeysGrpcService.java | 307 ++++++++
.../textsecuregcm/grpc/RateLimitUtil.java | 45 ++
.../textsecuregcm/limits/RateLimiter.java | 5 +
.../textsecuregcm/storage/KeysManager.java | 8 +
.../util/NoStackTraceRuntimeException.java | 29 +
.../textsecuregcm/util/UUIDUtil.java | 9 +
.../main/proto/org/signal/chat/common.proto | 71 ++
.../src/main/proto/org/signal/chat/keys.proto | 263 +++++++
.../auth/UnidentifiedAccessUtilTest.java | 52 ++
.../grpc/MockAuthenticationInterceptor.java | 46 ++
.../grpc/GrpcServerExtension.java | 119 +++
.../grpc/KeysAnonymousGrpcServiceTest.java | 211 ++++++
.../grpc/KeysGrpcServiceTest.java | 678 ++++++++++++++++++
24 files changed, 2112 insertions(+), 23 deletions(-)
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java
delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java
create mode 100644 service/src/main/proto/org/signal/chat/common.proto
create mode 100644 service/src/main/proto/org/signal/chat/keys.proto
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
diff --git a/pom.xml b/pom.xml
index a299fa84c..7610e7e4d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -62,6 +62,7 @@
1.2.0
3.21.7
0.15.2
+ 1.2.4
1.7.0
3.1.0
1.7.30
@@ -124,6 +125,11 @@
pom
import
+
+ com.salesforce.servicelibs
+ reactor-grpc-stub
+ ${reactive.grpc.version}
+
io.github.resilience4j
resilience4j-bom
@@ -398,6 +404,16 @@
com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
grpc-java
io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier}
+
+
+
+ reactor-grpc
+ com.salesforce.servicelibs
+ reactor-grpc
+ ${reactive.grpc.version}
+ com.salesforce.reactorgrpc.ReactorGrpcGenerator
+
+
diff --git a/service/pom.xml b/service/pom.xml
index a6a7b35c7..9ea5159e7 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -286,6 +286,11 @@
jackson-jaxrs-json-provider
+
+ com.salesforce.servicelibs
+ reactor-grpc-stub
+
+
software.amazon.awssdk
sts
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 4a3bcf50d..2939bab4c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -24,6 +24,7 @@ import io.dropwizard.auth.basic.BasicCredentials;
import io.dropwizard.setup.Bootstrap;
import io.dropwizard.setup.Environment;
import io.grpc.ServerBuilder;
+import io.grpc.ServerInterceptors;
import io.lettuce.core.metrics.MicrometerCommandLatencyRecorder;
import io.lettuce.core.metrics.MicrometerOptions;
import io.lettuce.core.resource.ClientResources;
@@ -64,6 +65,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
+import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
@@ -72,6 +74,7 @@ import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
+import org.whispersystems.textsecuregcm.auth.grpc.BasicCredentialAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.badges.ConfiguredProfileBadgeConverter;
import org.whispersystems.textsecuregcm.badges.ResourceBundleLevelTranslator;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
@@ -115,6 +118,8 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
+import org.whispersystems.textsecuregcm.grpc.KeysGrpcService;
+import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -401,6 +406,7 @@ public class WhisperServerService extends Application disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator(
disabledPermittedAccountAuthenticator).buildAuthFilter();
+ final BasicCredentialAuthenticationInterceptor basicCredentialAuthenticationInterceptor =
+ new BasicCredentialAuthenticationInterceptor(new BaseAccountAuthenticator(accountsManager));
+
final ServerBuilder> grpcServer = ServerBuilder.forPort(config.getGrpcPort())
- .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry)); /* TODO: specialize metrics with user-agent platform */
+ // TODO: specialize metrics with user-agent platform
+ .intercept(new MetricCollectingServerInterceptor(Metrics.globalRegistry))
+ .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
+ .addService(new KeysAnonymousGrpcService(accountsManager, keys));
RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
environment.servlets()
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java
new file mode 100644
index 000000000..32283170a
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtil.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth;
+
+import org.whispersystems.textsecuregcm.storage.Account;
+import java.security.MessageDigest;
+
+public class UnidentifiedAccessUtil {
+
+ private UnidentifiedAccessUtil() {
+ }
+
+ /**
+ * Checks whether an action (e.g. sending a message or retrieving pre-keys) may be taken on the target account by an
+ * actor presenting the given unidentified access key.
+ *
+ * @param targetAccount the account on which an actor wishes to take an action
+ * @param unidentifiedAccessKey the unidentified access key presented by the actor
+ *
+ * @return {@code true} if an actor presenting the given unidentified access key has permission to take an action on
+ * the target account or {@code false} otherwise
+ */
+ public static boolean checkUnidentifiedAccess(final Account targetAccount, final byte[] unidentifiedAccessKey) {
+ return targetAccount.isUnrestrictedUnidentifiedAccess()
+ || targetAccount.getUnidentifiedAccessKey()
+ .map(targetUnidentifiedAccessKey -> MessageDigest.isEqual(targetUnidentifiedAccessKey, unidentifiedAccessKey))
+ .orElse(false);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java
new file mode 100644
index 000000000..906056986
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java
@@ -0,0 +1,11 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth.grpc;
+
+import java.util.UUID;
+
+public record AuthenticatedDevice(UUID accountIdentifier, long deviceId) {
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java
index b68f0cf4e..a449b2092 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java
@@ -6,8 +6,12 @@
package org.whispersystems.textsecuregcm.auth.grpc;
import io.grpc.Context;
-import java.util.Optional;
+import io.grpc.Status;
import java.util.UUID;
+import javax.annotation.Nullable;
+import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;
+import reactor.util.function.Tuples;
/**
* Provides utility methods for working with authentication in the context of gRPC calls.
@@ -17,11 +21,23 @@ public class AuthenticationUtil {
static final Context.Key CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci");
static final Context.Key CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
- public static Optional getAuthenticatedAccountIdentifier() {
- return Optional.ofNullable(CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get());
- }
+ /**
+ * Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if
+ * no authenticated account/device is available.
+ *
+ * @return the account/device authenticated in the current gRPC context
+ *
+ * @throws io.grpc.StatusRuntimeException with a status of {@code UNAUTHENTICATED} if no authenticated account/device
+ * could be retrieved from the current gRPC context
+ */
+ public static AuthenticatedDevice requireAuthenticatedDevice() {
+ @Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get();
+ @Nullable final Long deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
- public static Optional getAuthenticatedDeviceIdentifier() {
- return Optional.ofNullable(CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get());
+ if (accountIdentifier != null && deviceId != null) {
+ return new AuthenticatedDevice(accountIdentifier, deviceId);
+ }
+
+ throw Status.UNAUTHENTICATED.asRuntimeException();
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java
index 307c00092..73a1df2b0 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java
@@ -24,9 +24,8 @@ import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
* Callers supply credentials by providing a username (UUID and optional device ID) and password pair in the
* {@code x-signal-basic-auth-credentials} call header.
*
- * Downstream services can retrieve the identity of the authenticated caller using
- * {@link AuthenticationUtil#getAuthenticatedAccountIdentifier()} and
- * {@link AuthenticationUtil#getAuthenticatedDeviceIdentifier()}.
+ * Downstream services can retrieve the identity of the authenticated caller using methods in
+ * {@link AuthenticationUtil}.
*
* Note that this authentication, while fully functional, is intended only for development and testing purposes and is
* intended to be replaced with a more robust and efficient strategy before widespread client adoption.
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java
deleted file mode 100644
index 67e596476..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/NotAuthenticatedException.java
+++ /dev/null
@@ -1,13 +0,0 @@
-/*
- * Copyright 2023 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.auth.grpc;
-
-/**
- * Indicates that a caller tried to get information about the authenticated gRPC caller, but no caller has been
- * authenticated.
- */
-public class NotAuthenticatedException extends Exception {
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java
new file mode 100644
index 000000000..d558348bf
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/IdentityType.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.Status;
+
+public enum IdentityType {
+ ACI,
+ PNI;
+
+ public static IdentityType fromGrpcIdentityType(final org.signal.chat.common.IdentityType grpcIdentityType) {
+ return switch (grpcIdentityType) {
+ case IDENTITY_TYPE_ACI -> ACI;
+ case IDENTITY_TYPE_PNI -> PNI;
+ case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw Status.INVALID_ARGUMENT.asRuntimeException();
+ };
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java
new file mode 100644
index 000000000..bc6410ea2
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.Status;
+import org.signal.chat.keys.GetPreKeysAnonymousRequest;
+import org.signal.chat.keys.GetPreKeysResponse;
+import org.signal.chat.keys.ReactorKeysAnonymousGrpc;
+import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import reactor.core.publisher.Mono;
+
+public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnonymousImplBase {
+
+ private final AccountsManager accountsManager;
+ private final KeysManager keysManager;
+
+ public KeysAnonymousGrpcService(final AccountsManager accountsManager, final KeysManager keysManager) {
+ this.accountsManager = accountsManager;
+ this.keysManager = keysManager;
+ }
+
+ @Override
+ public Mono getPreKeys(final GetPreKeysAnonymousRequest request) {
+ return KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager)
+ .switchIfEmpty(Mono.error(Status.UNAUTHENTICATED.asException()))
+ .flatMap(targetAccount -> {
+ final IdentityType identityType =
+ IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType());
+
+ return UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
+ ? KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager)
+ : Mono.error(Status.UNAUTHENTICATED.asException());
+ });
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java
new file mode 100644
index 000000000..948e1aa73
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.protobuf.ByteString;
+import io.grpc.Status;
+import java.util.UUID;
+import org.signal.chat.common.EcPreKey;
+import org.signal.chat.common.EcSignedPreKey;
+import org.signal.chat.common.KemSignedPreKey;
+import org.signal.chat.common.ServiceIdentifier;
+import org.signal.chat.keys.GetPreKeysResponse;
+import org.signal.libsignal.protocol.IdentityKey;
+import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;
+import reactor.util.function.Tuples;
+
+class KeysGrpcHelper {
+
+ @VisibleForTesting
+ static final long ALL_DEVICES = 0;
+
+ static Mono findAccount(final ServiceIdentifier targetIdentifier, final AccountsManager accountsManager) {
+
+ return Mono.just(IdentityType.fromGrpcIdentityType(targetIdentifier.getIdentityType()))
+ .flatMap(identityType -> {
+ final UUID uuid = UUIDUtil.fromByteString(targetIdentifier.getUuid());
+
+ return Mono.fromFuture(switch (identityType) {
+ case ACI -> accountsManager.getByAccountIdentifierAsync(uuid);
+ case PNI -> accountsManager.getByPhoneNumberIdentifierAsync(uuid);
+ });
+ })
+ .flatMap(Mono::justOrEmpty)
+ .onErrorMap(IllegalArgumentException.class, throwable -> Status.INVALID_ARGUMENT.asException());
+ }
+
+ static Tuple2 getIdentifierAndIdentityKey(final Account account, final IdentityType identityType) {
+ final UUID identifier = switch (identityType) {
+ case ACI -> account.getUuid();
+ case PNI -> account.getPhoneNumberIdentifier();
+ };
+
+ final IdentityKey identityKey = switch (identityType) {
+ case ACI -> account.getIdentityKey();
+ case PNI -> account.getPhoneNumberIdentityKey();
+ };
+
+ return Tuples.of(identifier, identityKey);
+ }
+
+ static Mono getPreKeys(final Account targetAccount, final IdentityType identityType, final long targetDeviceId, final KeysManager keysManager) {
+ final Tuple2 identifierAndIdentityKey = getIdentifierAndIdentityKey(targetAccount, identityType);
+
+ final Flux devices = targetDeviceId == ALL_DEVICES
+ ? Flux.fromIterable(targetAccount.getDevices())
+ : Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));
+
+ return devices
+ .filter(Device::isEnabled)
+ .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
+ .flatMap(device -> Mono.zip(Mono.fromFuture(keysManager.takeEC(identifierAndIdentityKey.getT1(), device.getId())),
+ Mono.fromFuture(keysManager.takePQ(identifierAndIdentityKey.getT1(), device.getId())))
+ .map(oneTimePreKeys -> {
+ final ECSignedPreKey ecSignedPreKey = switch (identityType) {
+ case ACI -> device.getSignedPreKey();
+ case PNI -> device.getPhoneNumberIdentitySignedPreKey();
+ };
+
+ final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder()
+ .setEcSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(ecSignedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
+ .build());
+
+ oneTimePreKeys.getT1().ifPresent(ecPreKey -> preKeyBundleBuilder.setEcOneTimePreKey(EcPreKey.newBuilder()
+ .setKeyId(ecPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
+ .build()));
+
+ oneTimePreKeys.getT2().ifPresent(kemSignedPreKey -> preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(kemSignedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
+ .build()));
+
+ return Tuples.of(device.getId(), preKeyBundleBuilder.build());
+ }))
+ .collectMap(Tuple2::getT1, Tuple2::getT2)
+ .map(preKeyBundles -> GetPreKeysResponse.newBuilder()
+ .setIdentityKey(ByteString.copyFrom(identifierAndIdentityKey.getT2().serialize()))
+ .putAllPreKeys(preKeyBundles)
+ .build());
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
new file mode 100644
index 000000000..f09abb2b8
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java
@@ -0,0 +1,307 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import static org.whispersystems.textsecuregcm.grpc.IdentityType.ACI;
+
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+import org.signal.chat.common.EcPreKey;
+import org.signal.chat.common.EcSignedPreKey;
+import org.signal.chat.common.KemSignedPreKey;
+import org.signal.chat.keys.GetPreKeyCountRequest;
+import org.signal.chat.keys.GetPreKeyCountResponse;
+import org.signal.chat.keys.GetPreKeysRequest;
+import org.signal.chat.keys.GetPreKeysResponse;
+import org.signal.chat.keys.ReactorKeysGrpc;
+import org.signal.chat.keys.SetEcSignedPreKeyRequest;
+import org.signal.chat.keys.SetKemLastResortPreKeyRequest;
+import org.signal.chat.keys.SetOneTimeEcPreKeysRequest;
+import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
+import org.signal.chat.keys.SetPreKeyResponse;
+import org.signal.libsignal.protocol.IdentityKey;
+import org.signal.libsignal.protocol.InvalidKeyException;
+import org.signal.libsignal.protocol.ecc.ECPublicKey;
+import org.signal.libsignal.protocol.kem.KEMPublicKey;
+import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
+import org.whispersystems.textsecuregcm.auth.grpc.AuthenticationUtil;
+import org.whispersystems.textsecuregcm.entities.ECPreKey;
+import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
+import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
+import org.whispersystems.textsecuregcm.limits.RateLimiters;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;
+import reactor.util.function.Tuples;
+
+public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
+
+ private final AccountsManager accountsManager;
+ private final KeysManager keysManager;
+ private final RateLimiters rateLimiters;
+
+ private static final StatusRuntimeException INVALID_PUBLIC_KEY_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT)
+ .withDescription("Invalid public key")
+ .asRuntimeException();
+
+ private static final StatusRuntimeException INVALID_SIGNATURE_EXCEPTION = Status.fromCode(Status.Code.INVALID_ARGUMENT)
+ .withDescription("Invalid signature")
+ .asRuntimeException();
+
+ private enum PreKeyType {
+ EC,
+ KEM
+ }
+
+ public KeysGrpcService(final AccountsManager accountsManager,
+ final KeysManager keysManager,
+ final RateLimiters rateLimiters) {
+
+ this.accountsManager = accountsManager;
+ this.keysManager = keysManager;
+ this.rateLimiters = rateLimiters;
+ }
+
+ @Override
+ protected Throwable onErrorMap(final Throwable throwable) {
+ return RateLimitUtil.mapRateLimitExceededException(throwable);
+ }
+
+ @Override
+ public Mono getPreKeyCount(final GetPreKeyCountRequest request) {
+ return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
+ .flatMap(authenticatedDevice -> Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier()))
+ .map(maybeAccount -> maybeAccount
+ .map(account -> Tuples.of(account, authenticatedDevice.deviceId()))
+ .orElseThrow(Status.UNAUTHENTICATED::asRuntimeException)))
+ .flatMapMany(accountAndDeviceId -> Flux.just(
+ Tuples.of(ACI, accountAndDeviceId.getT1().getUuid(), accountAndDeviceId.getT2()),
+ Tuples.of(IdentityType.PNI, accountAndDeviceId.getT1().getPhoneNumberIdentifier(), accountAndDeviceId.getT2())
+ ))
+ .flatMap(identityTypeUuidAndDeviceId -> Flux.merge(
+ Mono.fromFuture(keysManager.getEcCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3()))
+ .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.EC, ecKeyCount)),
+
+ Mono.fromFuture(keysManager.getPqCount(identityTypeUuidAndDeviceId.getT2(), identityTypeUuidAndDeviceId.getT3()))
+ .map(ecKeyCount -> Tuples.of(identityTypeUuidAndDeviceId.getT1(), PreKeyType.KEM, ecKeyCount))
+ ))
+ .reduce(GetPreKeyCountResponse.newBuilder(), (builder, tuple) -> {
+ final IdentityType identityType = tuple.getT1();
+ final PreKeyType preKeyType = tuple.getT2();
+ final int count = tuple.getT3();
+
+ switch (identityType) {
+ case ACI -> {
+ switch (preKeyType) {
+ case EC -> builder.setAciEcPreKeyCount(count);
+ case KEM -> builder.setAciKemPreKeyCount(count);
+ }
+ }
+ case PNI -> {
+ switch (preKeyType) {
+ case EC -> builder.setPniEcPreKeyCount(count);
+ case KEM -> builder.setPniKemPreKeyCount(count);
+ }
+ }
+ }
+
+ return builder;
+ })
+ .map(GetPreKeyCountResponse.Builder::build);
+ }
+
+ @Override
+ public Mono getPreKeys(final GetPreKeysRequest request) {
+ final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
+
+ final String rateLimitKey;
+ {
+ final UUID targetUuid;
+
+ try {
+ targetUuid = UUIDUtil.fromByteString(request.getTargetIdentifier().getUuid());
+ } catch (final IllegalArgumentException e) {
+ throw Status.INVALID_ARGUMENT.asRuntimeException();
+ }
+
+ rateLimitKey = authenticatedDevice.accountIdentifier() + "." +
+ authenticatedDevice.deviceId() + "__" +
+ targetUuid + "." +
+ request.getDeviceId();
+ }
+
+ return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey)
+ .then(KeysGrpcHelper.findAccount(request.getTargetIdentifier(), accountsManager))
+ .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
+ .flatMap(targetAccount -> {
+ final IdentityType identityType =
+ IdentityType.fromGrpcIdentityType(request.getTargetIdentifier().getIdentityType());
+
+ return KeysGrpcHelper.getPreKeys(targetAccount, identityType, request.getDeviceId(), keysManager);
+ });
+ }
+
+ @Override
+ public Mono setOneTimeEcPreKeys(final SetOneTimeEcPreKeysRequest request) {
+ return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
+ .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(),
+ request.getPreKeysList(),
+ IdentityType.fromGrpcIdentityType(request.getIdentityType()),
+ (requestPreKey, ignored) -> checkEcPreKey(requestPreKey),
+ (identifier, preKeys) -> keysManager.storeEcOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)));
+ }
+
+ @Override
+ public Mono setOneTimeKemSignedPreKeys(final SetOneTimeKemSignedPreKeysRequest request) {
+ return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
+ .flatMap(authenticatedDevice -> storeOneTimePreKeys(authenticatedDevice.accountIdentifier(),
+ request.getPreKeysList(),
+ IdentityType.fromGrpcIdentityType(request.getIdentityType()),
+ KeysGrpcService::checkKemSignedPreKey,
+ (identifier, preKeys) -> keysManager.storeKemOneTimePreKeys(identifier, authenticatedDevice.deviceId(), preKeys)));
+ }
+
+ private Mono storeOneTimePreKeys(final UUID authenticatedAccountUuid,
+ final List requestPreKeys,
+ final IdentityType identityType,
+ final BiFunction extractPreKeyFunction,
+ final BiFunction, CompletableFuture> storeKeysFunction) {
+
+ return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid))
+ .map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
+ .map(account -> {
+ final Tuple2 identifierAndIdentityKey =
+ KeysGrpcHelper.getIdentifierAndIdentityKey(account, identityType);
+
+ final List preKeys = requestPreKeys.stream()
+ .map(requestPreKey -> extractPreKeyFunction.apply(requestPreKey, identifierAndIdentityKey.getT2()))
+ .toList();
+
+ if (preKeys.isEmpty()) {
+ throw Status.INVALID_ARGUMENT.asRuntimeException();
+ }
+
+ return Tuples.of(identifierAndIdentityKey.getT1(), preKeys);
+ })
+ .flatMap(identifierAndPreKeys -> Mono.fromFuture(storeKeysFunction.apply(identifierAndPreKeys.getT1(), identifierAndPreKeys.getT2())))
+ .thenReturn(SetPreKeyResponse.newBuilder().build());
+ }
+
+ @Override
+ public Mono setEcSignedPreKey(final SetEcSignedPreKeyRequest request) {
+ return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
+ .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(),
+ request.getIdentityType(),
+ request.getSignedPreKey(),
+ KeysGrpcService::checkEcSignedPreKey,
+ (account, signedPreKey) -> {
+ final Consumer deviceUpdater = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
+ case ACI -> device -> device.setSignedPreKey(signedPreKey);
+ case PNI -> device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
+ };
+
+ final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
+ case ACI -> account.getUuid();
+ case PNI -> account.getPhoneNumberIdentifier();
+ };
+
+ return Flux.merge(
+ Mono.fromFuture(keysManager.storeEcSignedPreKeys(identifier, Map.of(authenticatedDevice.deviceId(), signedPreKey))),
+ Mono.fromFuture(accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater)))
+ .then();
+ }));
+ }
+
+ @Override
+ public Mono setKemLastResortPreKey(final SetKemLastResortPreKeyRequest request) {
+ return Mono.fromSupplier(AuthenticationUtil::requireAuthenticatedDevice)
+ .flatMap(authenticatedDevice -> storeRepeatedUseKey(authenticatedDevice.accountIdentifier(),
+ request.getIdentityType(),
+ request.getSignedPreKey(),
+ KeysGrpcService::checkKemSignedPreKey,
+ (account, lastResortKey) -> {
+ final UUID identifier = switch (IdentityType.fromGrpcIdentityType(request.getIdentityType())) {
+ case ACI -> account.getUuid();
+ case PNI -> account.getPhoneNumberIdentifier();
+ };
+
+ return Mono.fromFuture(keysManager.storePqLastResort(identifier, Map.of(authenticatedDevice.deviceId(), lastResortKey)));
+ }));
+ }
+
+ private Mono storeRepeatedUseKey(final UUID authenticatedAccountUuid,
+ final org.signal.chat.common.IdentityType identityType,
+ final R storeKeyRequest,
+ final BiFunction extractKeyFunction,
+ final BiFunction> storeKeyFunction) {
+
+ return Mono.fromFuture(accountsManager.getByAccountIdentifierAsync(authenticatedAccountUuid))
+ .map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
+ .map(account -> {
+ final IdentityKey identityKey = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> account.getIdentityKey();
+ case PNI -> account.getPhoneNumberIdentityKey();
+ };
+
+ final K key = extractKeyFunction.apply(storeKeyRequest, identityKey);
+
+ return Tuples.of(account, key);
+ })
+ .flatMap(accountAndKey -> storeKeyFunction.apply(accountAndKey.getT1(), accountAndKey.getT2()))
+ .thenReturn(SetPreKeyResponse.newBuilder().build());
+ }
+
+ private static ECPreKey checkEcPreKey(final EcPreKey preKey) {
+ try {
+ return new ECPreKey(preKey.getKeyId(), new ECPublicKey(preKey.getPublicKey().toByteArray()));
+ } catch (final InvalidKeyException e) {
+ throw INVALID_PUBLIC_KEY_EXCEPTION;
+ }
+ }
+
+ private static ECSignedPreKey checkEcSignedPreKey(final EcSignedPreKey preKey, final IdentityKey identityKey) {
+ try {
+ final ECSignedPreKey ecSignedPreKey = new ECSignedPreKey(preKey.getKeyId(),
+ new ECPublicKey(preKey.getPublicKey().toByteArray()),
+ preKey.getSignature().toByteArray());
+
+ if (ecSignedPreKey.signatureValid(identityKey)) {
+ return ecSignedPreKey;
+ } else {
+ throw INVALID_SIGNATURE_EXCEPTION;
+ }
+ } catch (final InvalidKeyException e) {
+ throw INVALID_PUBLIC_KEY_EXCEPTION;
+ }
+ }
+
+ private static KEMSignedPreKey checkKemSignedPreKey(final KemSignedPreKey preKey, final IdentityKey identityKey) {
+ try {
+ final KEMSignedPreKey kemSignedPreKey = new KEMSignedPreKey(preKey.getKeyId(),
+ new KEMPublicKey(preKey.getPublicKey().toByteArray()),
+ preKey.getSignature().toByteArray());
+
+ if (kemSignedPreKey.signatureValid(identityKey)) {
+ return kemSignedPreKey;
+ } else {
+ throw INVALID_SIGNATURE_EXCEPTION;
+ }
+ } catch (final InvalidKeyException e) {
+ throw INVALID_PUBLIC_KEY_EXCEPTION;
+ }
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
new file mode 100644
index 000000000..dadfd5417
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/RateLimitUtil.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.StatusException;
+import java.time.Duration;
+import javax.annotation.Nullable;
+import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+
+public class RateLimitUtil {
+
+ public static final Metadata.Key RETRY_AFTER_DURATION_KEY =
+ Metadata.Key.of("retry-after", new Metadata.AsciiMarshaller<>() {
+ @Override
+ public String toAsciiString(final Duration value) {
+ return value.toString();
+ }
+
+ @Override
+ public Duration parseAsciiString(final String serialized) {
+ return Duration.parse(serialized);
+ }
+ });
+
+ public static Throwable mapRateLimitExceededException(final Throwable throwable) {
+ if (throwable instanceof RateLimitExceededException rateLimitExceededException) {
+ @Nullable final Metadata trailers = rateLimitExceededException.getRetryDuration()
+ .map(duration -> {
+ final Metadata metadata = new Metadata();
+ metadata.put(RETRY_AFTER_DURATION_KEY, duration);
+
+ return metadata;
+ }).orElse(null);
+
+ return new StatusException(Status.RESOURCE_EXHAUSTED, trailers);
+ }
+
+ return throwable;
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
index 70e7d1cd3..b53df9a1c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
@@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import java.util.UUID;
import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import reactor.core.publisher.Mono;
public interface RateLimiter {
@@ -53,6 +54,10 @@ public interface RateLimiter {
return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
+ default Mono validateReactive(final String key) {
+ return Mono.fromFuture(validateAsync(key).toCompletableFuture());
+ }
+
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits);
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java
index 90e0ea8c0..65167ceb3 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java
@@ -90,6 +90,14 @@ public class KeysManager {
return pqLastResortKeys.store(identifier, keys);
}
+ public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) {
+ return ecPreKeys.store(identifier, deviceId, preKeys);
+ }
+
+ public CompletableFuture storeKemOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) {
+ return pqPreKeys.store(identifier, deviceId, preKeys);
+ }
+
public CompletableFuture> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId);
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java
new file mode 100644
index 000000000..1fa8c6824
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/NoStackTraceRuntimeException.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.util;
+
+/**
+ * An abstract base class for runtime exceptions that do not include a stack trace. Stackless exceptions are generally
+ * intended for internal error-handling cases where the error will never be logged or otherwise reported.
+ */
+public abstract class NoStackTraceRuntimeException extends RuntimeException {
+
+ public NoStackTraceRuntimeException() {
+ super(null, null, true, false);
+ }
+
+ public NoStackTraceRuntimeException(final String message) {
+ super(message, null, true, false);
+ }
+
+ public NoStackTraceRuntimeException(final String message, final Throwable cause) {
+ super(message, cause, true, false);
+ }
+
+ public NoStackTraceRuntimeException(final Throwable cause) {
+ super(null, cause, true, false);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
index 3f1d91d7c..38f2cb34b 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UUIDUtil.java
@@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.util;
+import com.google.protobuf.ByteString;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.Optional;
@@ -27,6 +28,14 @@ public final class UUIDUtil {
return byteBuffer.flip();
}
+ public static ByteString toByteString(final UUID uuid) {
+ return ByteString.copyFrom(toByteBuffer(uuid));
+ }
+
+ public static UUID fromByteString(final ByteString byteString) {
+ return fromBytes(byteString.toByteArray());
+ }
+
public static UUID fromBytes(final byte[] bytes) {
return fromByteBuffer(ByteBuffer.wrap(bytes));
}
diff --git a/service/src/main/proto/org/signal/chat/common.proto b/service/src/main/proto/org/signal/chat/common.proto
new file mode 100644
index 000000000..5add25f6c
--- /dev/null
+++ b/service/src/main/proto/org/signal/chat/common.proto
@@ -0,0 +1,71 @@
+syntax = "proto3";
+
+option java_multiple_files = true;
+
+package org.signal.chat.common;
+
+enum IdentityType {
+ IDENTITY_TYPE_UNSPECIFIED = 0;
+ IDENTITY_TYPE_ACI = 1;
+ IDENTITY_TYPE_PNI = 2;
+}
+
+message ServiceIdentifier {
+ /**
+ * The type of identity represented by this service identifier.
+ */
+ IdentityType identity_type = 1;
+
+ /**
+ * The UUID of the identity represented by this service identifier.
+ */
+ bytes uuid = 2;
+}
+
+message EcPreKey {
+ /**
+ * A locally-unique identifier for this key.
+ */
+ uint64 key_id = 1;
+
+ /**
+ * The serialized form of the public key.
+ */
+ bytes public_key = 2;
+}
+
+message EcSignedPreKey {
+ /**
+ * A locally-unique identifier for this key.
+ */
+ uint64 key_id = 1;
+
+ /**
+ * The serialized form of the public key.
+ */
+ bytes public_key = 2;
+
+ /**
+ * A signature of the public key, verifiable with the identity key for the
+ * account/identity associated with this pre-key.
+ */
+ bytes signature = 3;
+}
+
+message KemSignedPreKey {
+ /**
+ * A locally-unique identifier for this key.
+ */
+ uint64 key_id = 1;
+
+ /**
+ * The serialized form of the public key.
+ */
+ bytes public_key = 2;
+
+ /**
+ * A signature of the public key, verifiable with the identity key for the
+ * account/identity associated with this pre-key.
+ */
+ bytes signature = 3;
+}
diff --git a/service/src/main/proto/org/signal/chat/keys.proto b/service/src/main/proto/org/signal/chat/keys.proto
new file mode 100644
index 000000000..7f7af8fb6
--- /dev/null
+++ b/service/src/main/proto/org/signal/chat/keys.proto
@@ -0,0 +1,263 @@
+syntax = "proto3";
+
+option java_multiple_files = true;
+
+package org.signal.chat.keys;
+
+import "org/signal/chat/common.proto";
+
+/**
+ * Provides methods for working with pre-keys.
+ */
+service Keys {
+
+ /**
+ * Retrieves an approximate count of the number of the various kinds of
+ * pre-keys stored for the authenticated device.
+ */
+ rpc GetPreKeyCount (GetPreKeyCountRequest) returns (GetPreKeyCountResponse) {}
+
+ /**
+ * Retrieves a set of pre-keys for establishing a session with the targeted
+ * device or devices. Note that callers with an unidentified access key for
+ * the targeted account should use the version of this method in
+ * `KeysAnonymous` instead.
+ *
+ * This RPC may fail with a `NOT_FOUND` status if the target account was not
+ * found, if no active device with the given ID (if specified) was found on
+ * the target account, or if the account has no active devices. It may also
+ * fail with a `RESOURCE_EXHAUSTED` if a rate limit for fetching keys has been
+ * exceeded, in which case a `retry-after` header containing an ISO 8601
+ * duration string will be present in the response trailers.
+ */
+ rpc GetPreKeys(GetPreKeysRequest) returns (GetPreKeysResponse) {}
+
+ /**
+ * Uploads a new set of one-time EC pre-keys for the authenticated device,
+ * clearing any previously-stored pre-keys. Note that all keys submitted via
+ * a single call to this method _must_ have the same identity type (i.e. if
+ * the first key has an ACI identity type, then all other keys in the same
+ * stream must also have an ACI identity type).
+ *
+ * This method returns a single status code; if all keys were validated/stored
+ * successfully, then this method will return `SET_PRE_KEY_STATUS_OK`. If one
+ * or more keys could not be stored, this method will return a status code
+ * indicating the reason. If multiple keys had problems, which status code
+ * will be returned is not defined.
+ *
+ * This RPC may fail with an `INVALID_ARGUMENT` status if one or more of the
+ * given pre-keys was structurally invalid or if the list of pre-keys was
+ * empty.
+ */
+ rpc SetOneTimeEcPreKeys (SetOneTimeEcPreKeysRequest) returns (SetPreKeyResponse) {}
+
+ /**
+ * Uploads a new set of one-time KEM pre-keys for the authenticated device,
+ * clearing any previously-stored pre-keys. Note that all keys submitted via
+ * a single call to this method _must_ have the same identity type (i.e. if
+ * the first key has an ACI identity type, then all other keys in the same
+ * stream must also have an ACI identity type).
+ *
+ * This method returns a single status code; if all keys were validated/stored
+ * successfully, then this method will return `SET_PRE_KEY_STATUS_OK`. If one
+ * or more keys could not be stored, this method will return a status code
+ * indicating the reason. If multiple keys had problems, which status code
+ * will be returned is not defined.
+ *
+ * This RPC may fail with an `INVALID_ARGUMENT` status if one or more of the
+ * given pre-keys was structurally invalid, had an invalid signature, or if
+ * the list of pre-keys was empty.
+ */
+ rpc SetOneTimeKemSignedPreKeys (SetOneTimeKemSignedPreKeysRequest) returns (SetPreKeyResponse) {}
+
+ /**
+ * Sets the signed EC pre-key for one identity (i.e. ACI or PNI) associated
+ * with the authenticated device.
+ *
+ * This RPC may fail with an `INVALID_ARGUMENT` status if the given pre-key
+ * was structurally invalid, had a bad signature, or was missing entirely.
+ */
+ rpc SetEcSignedPreKey (SetEcSignedPreKeyRequest) returns (SetPreKeyResponse) {}
+
+ /**
+ * Sets the last-resort KEM pre-key for one identity (i.e. ACI or PNI)
+ * associated with the authenticated device.
+ *
+ * This RPC may fail with an `INVALID_ARGUMENT` status if the given pre-key
+ * was structurally invalid, had a bad signature, or was missing entirely.
+ */
+ rpc SetKemLastResortPreKey (SetKemLastResortPreKeyRequest) returns (SetPreKeyResponse) {}
+}
+
+/**
+ * Provides methods for working with pre-keys using "unidentified access"
+ * credentials.
+ */
+service KeysAnonymous {
+
+ /**
+ * Retrieves a set of pre-keys for establishing a session with the targeted
+ * device or devices. Callers must not submit any self-identifying credentials
+ * when calling this method and must instead present the targeted account's
+ * unidentified access key as an anonymous authentication mechanism. Callers
+ * without an unidentified access key should use the equivalent, authenticated
+ * method in `Keys` instead.
+ *
+ * This RPC may fail with an `UNAUTHENTICATED` status if the given
+ * unidentified access key did not match the target account's unidentified
+ * access key or if the target account was not found. It may also fail with a
+ * `NOT_FOUND` status if no active device with the given ID (if specified) was
+ * found on the target account, or if the target account has no active
+ * devices.
+ */
+ rpc GetPreKeys(GetPreKeysAnonymousRequest) returns (GetPreKeysResponse) {}
+}
+
+message GetPreKeyCountRequest {
+}
+
+message GetPreKeyCountResponse {
+ /**
+ * The approximate number of one-time EC pre-keys stored for the
+ * authenticated device and associated with the caller's ACI.
+ */
+ uint32 aci_ec_pre_key_count = 1;
+
+ /**
+ * The approximate number of one-time Kyber pre-keys stored for the
+ * authenticated device and associated with the caller's ACI.
+ */
+ uint32 aci_kem_pre_key_count = 2;
+
+ /**
+ * The approximate number of one-time EC pre-keys stored for the
+ * authenticated device and associated with the caller's PNI.
+ */
+ uint32 pni_ec_pre_key_count = 3;
+
+ /**
+ * The approximate number of one-time KEM pre-keys stored for the
+ * authenticated device and associated with the caller's PNI.
+ */
+ uint32 pni_kem_pre_key_count = 4;
+}
+
+message GetPreKeysRequest {
+ /**
+ * The service identifier of the account for which to retrieve pre-keys.
+ */
+ common.ServiceIdentifier target_identifier = 1;
+
+ /**
+ * The ID of the device associated with the targeted account for which to
+ * retrieve pre-keys. If not set, pre-keys are returned for all devices
+ * associated with the targeted account.
+ */
+ uint64 device_id = 2;
+}
+
+message GetPreKeysAnonymousRequest {
+ /**
+ * The service identifier of the account for which to retrieve pre-keys.
+ */
+ common.ServiceIdentifier target_identifier = 1;
+
+ /**
+ * The ID of the device associated with the targeted account for which to
+ * retrieve pre-keys. If not set, pre-keys are returned for all devices
+ * associated with the targeted account.
+ */
+ uint64 device_id = 2;
+
+ /**
+ * The unidentified access key (UAK) for the targeted account.
+ */
+ bytes unidentified_access_key = 3;
+}
+
+message GetPreKeysResponse {
+ message PreKeyBundle {
+ /**
+ * The EC signed pre-key associated with the targeted
+ * account/device/identity.
+ */
+ common.EcSignedPreKey ec_signed_pre_key = 1;
+
+ /**
+ * A one-time EC pre-key for the targeted account/device/identity. May not
+ * be set if no one-time EC pre-keys are available.
+ */
+ common.EcPreKey ec_one_time_pre_key = 2;
+
+ /**
+ * A one-time KEM pre-key (or a last-resort KEM pre-key) for the targeted
+ * account/device/identity. May not be set if the targeted device has not
+ * yet uploaded any KEM pre-keys.
+ */
+ common.KemSignedPreKey kem_one_time_pre_key = 3;
+ }
+
+ /**
+ * The identity key associated with the targeted account/identity.
+ */
+ bytes identity_key = 1;
+
+ /**
+ * A map of device IDs to pre-key "bundles" for the targeted account.
+ */
+ map pre_keys = 2;
+}
+
+message SetOneTimeEcPreKeysRequest {
+ /**
+ * The identity type (i.e. ACI/PNI) with which the keys in this request are
+ * associated.
+ */
+ common.IdentityType identity_type = 1;
+
+ /**
+ * The unsigned EC pre-keys to be stored.
+ */
+ repeated common.EcPreKey pre_keys = 2;
+}
+
+message SetOneTimeKemSignedPreKeysRequest {
+ /**
+ * The identity type (i.e. ACI/PNI) with which the keys in this request are
+ * associated.
+ */
+ common.IdentityType identity_type = 1;
+
+ /**
+ * The KEM pre-keys to be stored.
+ */
+ repeated common.KemSignedPreKey pre_keys = 2;
+}
+
+message SetEcSignedPreKeyRequest {
+ /**
+ * The identity type (i.e. ACI/PNI) with which this key is associated.
+ */
+ common.IdentityType identity_type = 1;
+
+ /**
+ * The signed EC pre-key itself.
+ */
+ common.EcSignedPreKey signed_pre_key = 2;
+}
+
+message SetKemLastResortPreKeyRequest {
+ /**
+ * The identity type (i.e. ACI/PNI) with which this key is associated.
+ */
+ common.IdentityType identity_type = 1;
+
+ /**
+ * The signed KEM pre-key itself.
+ */
+ common.KemSignedPreKey signed_pre_key = 2;
+}
+
+message SetPreKeyResponse {
+}
+
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java
new file mode 100644
index 000000000..993974b07
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/UnidentifiedAccessUtilTest.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.security.SecureRandom;
+import java.util.Optional;
+import java.util.stream.Stream;
+import javax.annotation.Nullable;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.whispersystems.textsecuregcm.storage.Account;
+
+class UnidentifiedAccessUtilTest {
+
+ @ParameterizedTest
+ @MethodSource
+ void checkUnidentifiedAccess(@Nullable final byte[] targetUak,
+ final boolean unrestrictedUnidentifiedAccess,
+ final byte[] presentedUak,
+ final boolean expectAccessAllowed) {
+
+ final Account account = mock(Account.class);
+ when(account.getUnidentifiedAccessKey()).thenReturn(Optional.ofNullable(targetUak));
+ when(account.isUnrestrictedUnidentifiedAccess()).thenReturn(unrestrictedUnidentifiedAccess);
+
+ assertEquals(expectAccessAllowed, UnidentifiedAccessUtil.checkUnidentifiedAccess(account, presentedUak));
+ }
+
+ private static Stream checkUnidentifiedAccess() {
+ final byte[] uak = new byte[16];
+ new SecureRandom().nextBytes(uak);
+
+ final byte[] incorrectUak = new byte[uak.length + 1];
+
+ return Stream.of(
+ Arguments.of(null, false, uak, false),
+ Arguments.of(null, true, uak, true),
+ Arguments.of(uak, false, incorrectUak, false),
+ Arguments.of(uak, false, uak, true),
+ Arguments.of(uak, true, incorrectUak, true),
+ Arguments.of(uak, true, uak, true)
+ );
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java
new file mode 100644
index 000000000..a4fd52df8
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth.grpc;
+
+import io.grpc.Context;
+import io.grpc.Contexts;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import org.whispersystems.textsecuregcm.util.Pair;
+
+public class MockAuthenticationInterceptor implements ServerInterceptor {
+
+ @Nullable
+ private Pair authenticatedDevice;
+
+ public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) {
+ authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
+ }
+
+ public void clearAuthenticatedDevice() {
+ authenticatedDevice = null;
+ }
+
+ @Override
+ public ServerCall.Listener interceptCall(final ServerCall call,
+ final Metadata headers,
+ final ServerCallHandler next) {
+
+ if (authenticatedDevice != null) {
+ final Context context = Context.current()
+ .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY, authenticatedDevice.first())
+ .withValue(AuthenticationUtil.CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY, authenticatedDevice.second());
+
+ return Contexts.interceptCall(context, call, headers, next);
+ }
+
+ return next.startCall(call, headers);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java
new file mode 100644
index 000000000..725ecaa9d
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/GrpcServerExtension.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import io.grpc.BindableService;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.util.MutableHandlerRegistry;
+import org.junit.jupiter.api.extension.AfterEachCallback;
+import org.junit.jupiter.api.extension.BeforeEachCallback;
+import org.junit.jupiter.api.extension.ExtensionContext;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+// This is mostly a direct port of
+// https://github.com/grpc/grpc-java/blob/master/testing/src/main/java/io/grpc/testing/GrpcServerRule.java, but for
+// JUnit 5.
+public class GrpcServerExtension implements BeforeEachCallback, AfterEachCallback {
+
+ private ManagedChannel channel;
+ private Server server;
+ private String serverName;
+ private MutableHandlerRegistry serviceRegistry;
+ private boolean useDirectExecutor;
+
+ /**
+ * Returns {@code this} configured to use a direct executor for the {@link ManagedChannel} and
+ * {@link Server}. This can only be called at the rule instantiation.
+ */
+ public final GrpcServerExtension directExecutor() {
+ if (serverName != null) {
+ throw new IllegalStateException("directExecutor() can only be called at the rule instantiation");
+ }
+
+ useDirectExecutor = true;
+ return this;
+ }
+
+ /**
+ * Returns a {@link ManagedChannel} connected to this service.
+ */
+ public final ManagedChannel getChannel() {
+ return channel;
+ }
+
+ /**
+ * Returns the underlying gRPC {@link Server} for this service.
+ */
+ public final Server getServer() {
+ return server;
+ }
+
+ /**
+ * Returns the randomly generated server name for this service.
+ */
+ public final String getServerName() {
+ return serverName;
+ }
+
+ /**
+ * Returns the service registry for this service. The registry is used to add service instances
+ * (e.g. {@link BindableService} or {@link ServerServiceDefinition} to the server.
+ */
+ public final MutableHandlerRegistry getServiceRegistry() {
+ return serviceRegistry;
+ }
+
+ @Override
+ public void beforeEach(final ExtensionContext extensionContext) throws Exception {
+ serverName = UUID.randomUUID().toString();
+ serviceRegistry = new MutableHandlerRegistry();
+
+ final InProcessServerBuilder serverBuilder = InProcessServerBuilder.forName(serverName)
+ .fallbackHandlerRegistry(serviceRegistry);
+
+ if (useDirectExecutor) {
+ serverBuilder.directExecutor();
+ }
+
+ server = serverBuilder.build().start();
+
+ final InProcessChannelBuilder channelBuilder = InProcessChannelBuilder.forName(serverName);
+
+ if (useDirectExecutor) {
+ channelBuilder.directExecutor();
+ }
+
+ channel = channelBuilder.build();
+ }
+
+ @Override
+ public void afterEach(final ExtensionContext extensionContext) throws Exception {
+ serverName = null;
+ serviceRegistry = null;
+
+ channel.shutdown();
+ server.shutdown();
+
+ try {
+ channel.awaitTermination(1, TimeUnit.MINUTES);
+ server.awaitTermination(1, TimeUnit.MINUTES);
+ } catch (final InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ } finally {
+ channel.shutdownNow();
+ channel = null;
+
+ server.shutdownNow();
+ server = null;
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
new file mode 100644
index 000000000..81cf7e68e
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcServiceTest.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.google.protobuf.ByteString;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import java.security.SecureRandom;
+import java.util.Collections;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.signal.chat.common.EcPreKey;
+import org.signal.chat.common.EcSignedPreKey;
+import org.signal.chat.common.IdentityType;
+import org.signal.chat.common.KemSignedPreKey;
+import org.signal.chat.common.ServiceIdentifier;
+import org.signal.chat.keys.GetPreKeysAnonymousRequest;
+import org.signal.chat.keys.GetPreKeysResponse;
+import org.signal.chat.keys.KeysAnonymousGrpc;
+import org.signal.libsignal.protocol.IdentityKey;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.whispersystems.textsecuregcm.entities.ECPreKey;
+import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
+import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+
+class KeysAnonymousGrpcServiceTest {
+
+ private AccountsManager accountsManager;
+ private KeysManager keysManager;
+
+ private KeysAnonymousGrpc.KeysAnonymousBlockingStub keysAnonymousStub;
+
+ @RegisterExtension
+ static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
+
+ @BeforeEach
+ void setUp() {
+ accountsManager = mock(AccountsManager.class);
+ keysManager = mock(KeysManager.class);
+
+ final KeysAnonymousGrpcService keysGrpcService =
+ new KeysAnonymousGrpcService(accountsManager, keysManager);
+
+ keysAnonymousStub = KeysAnonymousGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
+
+ GRPC_SERVER_EXTENSION.getServiceRegistry().addService(keysGrpcService);
+ }
+
+ @Test
+ void getPreKeys() {
+ final Account targetAccount = mock(Account.class);
+ final Device targetDevice = mock(Device.class);
+
+ final ECKeyPair identityKeyPair = Curve.generateKeyPair();
+ final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
+ final UUID identifier = UUID.randomUUID();
+
+ final byte[] unidentifiedAccessKey = new byte[16];
+ new SecureRandom().nextBytes(unidentifiedAccessKey);
+
+ when(targetDevice.getId()).thenReturn(Device.MASTER_ID);
+ when(targetDevice.isEnabled()).thenReturn(true);
+ when(targetAccount.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(targetDevice));
+
+ when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
+ when(targetAccount.getUuid()).thenReturn(identifier);
+ when(targetAccount.getIdentityKey()).thenReturn(identityKey);
+ when(accountsManager.getByAccountIdentifierAsync(identifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+
+ final ECPreKey ecPreKey = new ECPreKey(1, Curve.generateKeyPair().getPublicKey());
+ final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(2, identityKeyPair);
+ final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(3, identityKeyPair);
+
+ when(keysManager.takeEC(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey)));
+ when(keysManager.takePQ(identifier, Device.MASTER_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
+ when(targetDevice.getSignedPreKey()).thenReturn(ecSignedPreKey);
+
+ final GetPreKeysResponse response = keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(identifier))
+ .build())
+ .setDeviceId(Device.MASTER_ID)
+ .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
+ .build());
+
+ final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
+ .setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
+ .putPreKeys(Device.MASTER_ID, GetPreKeysResponse.PreKeyBundle.newBuilder()
+ .setEcOneTimePreKey(EcPreKey.newBuilder()
+ .setKeyId(ecPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
+ .build())
+ .setEcSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(ecSignedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
+ .build())
+ .setKemOneTimePreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(kemSignedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
+ .build())
+ .build())
+ .build();
+
+ assertEquals(expectedResponse, response);
+ }
+
+ @Test
+ void getPreKeysIncorrectUnidentifiedAccessKey() {
+ final Account targetAccount = mock(Account.class);
+
+ final ECKeyPair identityKeyPair = Curve.generateKeyPair();
+ final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
+ final UUID identifier = UUID.randomUUID();
+
+ final byte[] unidentifiedAccessKey = new byte[16];
+ new SecureRandom().nextBytes(unidentifiedAccessKey);
+
+ when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
+ when(targetAccount.getUuid()).thenReturn(identifier);
+ when(targetAccount.getIdentityKey()).thenReturn(identityKey);
+ when(accountsManager.getByAccountIdentifierAsync(identifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException statusRuntimeException =
+ assertThrows(StatusRuntimeException.class,
+ () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(identifier))
+ .build())
+ .setDeviceId(Device.MASTER_ID)
+ .build()));
+
+ assertEquals(Status.UNAUTHENTICATED.getCode(), statusRuntimeException.getStatus().getCode());
+ }
+
+ @Test
+ void getPreKeysAccountNotFound() {
+ when(accountsManager.getByAccountIdentifierAsync(any()))
+ .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setUnidentifiedAccessKey(UUIDUtil.toByteString(UUID.randomUUID()))
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
+ .build())
+ .build()));
+
+ assertEquals(Status.Code.UNAUTHENTICATED, exception.getStatus().getCode());
+ }
+
+ @ParameterizedTest
+ @ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
+ void getPreKeysDeviceNotFound(final long deviceId) {
+ final UUID accountIdentifier = UUID.randomUUID();
+
+ final byte[] unidentifiedAccessKey = new byte[16];
+ new SecureRandom().nextBytes(unidentifiedAccessKey);
+
+ final Account targetAccount = mock(Account.class);
+ when(targetAccount.getUuid()).thenReturn(accountIdentifier);
+ when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
+ when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
+ when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
+ when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
+
+ when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysAnonymousStub.getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
+ .setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(accountIdentifier))
+ .build())
+ .setDeviceId(deviceId)
+ .build()));
+
+ assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
new file mode 100644
index 000000000..d2431a281
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcServiceTest.java
@@ -0,0 +1,678 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.grpc;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.protobuf.ByteString;
+import io.grpc.ServerInterceptors;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.signal.chat.common.EcPreKey;
+import org.signal.chat.common.EcSignedPreKey;
+import org.signal.chat.common.KemSignedPreKey;
+import org.signal.chat.common.ServiceIdentifier;
+import org.signal.chat.keys.GetPreKeyCountRequest;
+import org.signal.chat.keys.GetPreKeyCountResponse;
+import org.signal.chat.keys.GetPreKeysRequest;
+import org.signal.chat.keys.GetPreKeysResponse;
+import org.signal.chat.keys.KeysGrpc;
+import org.signal.chat.keys.SetEcSignedPreKeyRequest;
+import org.signal.chat.keys.SetKemLastResortPreKeyRequest;
+import org.signal.chat.keys.SetOneTimeEcPreKeysRequest;
+import org.signal.chat.keys.SetOneTimeKemSignedPreKeysRequest;
+import org.signal.libsignal.protocol.IdentityKey;
+import org.signal.libsignal.protocol.ecc.Curve;
+import org.signal.libsignal.protocol.ecc.ECKeyPair;
+import org.whispersystems.textsecuregcm.auth.grpc.MockAuthenticationInterceptor;
+import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.entities.ECPreKey;
+import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
+import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
+import org.whispersystems.textsecuregcm.limits.RateLimiter;
+import org.whispersystems.textsecuregcm.limits.RateLimiters;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
+import org.whispersystems.textsecuregcm.util.UUIDUtil;
+import reactor.core.publisher.Mono;
+
+class KeysGrpcServiceTest {
+
+ private AccountsManager accountsManager;
+ private KeysManager keysManager;
+ private RateLimiter preKeysRateLimiter;
+
+ private Device authenticatedDevice;
+
+ private KeysGrpc.KeysBlockingStub keysStub;
+
+ private static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
+ private static final UUID AUTHENTICATED_PNI = UUID.randomUUID();
+ private static final long AUTHENTICATED_DEVICE_ID = Device.MASTER_ID;
+
+ private static final ECKeyPair ACI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
+ private static final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
+
+ @RegisterExtension
+ static final GrpcServerExtension GRPC_SERVER_EXTENSION = new GrpcServerExtension();
+
+ @BeforeEach
+ void setUp() {
+ accountsManager = mock(AccountsManager.class);
+ keysManager = mock(KeysManager.class);
+ preKeysRateLimiter = mock(RateLimiter.class);
+
+ final RateLimiters rateLimiters = mock(RateLimiters.class);
+ when(rateLimiters.getPreKeysLimiter()).thenReturn(preKeysRateLimiter);
+
+ when(preKeysRateLimiter.validateReactive(anyString())).thenReturn(Mono.empty());
+
+ final KeysGrpcService keysGrpcService = new KeysGrpcService(accountsManager, keysManager, rateLimiters);
+ keysStub = KeysGrpc.newBlockingStub(GRPC_SERVER_EXTENSION.getChannel());
+
+ authenticatedDevice = mock(Device.class);
+ when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
+
+ final Account authenticatedAccount = mock(Account.class);
+ when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
+ when(authenticatedAccount.getPhoneNumberIdentifier()).thenReturn(AUTHENTICATED_PNI);
+ when(authenticatedAccount.getIdentityKey()).thenReturn(new IdentityKey(ACI_IDENTITY_KEY_PAIR.getPublicKey()));
+ when(authenticatedAccount.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()));
+ when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(authenticatedDevice));
+
+ final MockAuthenticationInterceptor mockAuthenticationInterceptor = new MockAuthenticationInterceptor();
+ mockAuthenticationInterceptor.setAuthenticatedDevice(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID);
+
+ GRPC_SERVER_EXTENSION.getServiceRegistry()
+ .addService(ServerInterceptors.intercept(keysGrpcService, mockAuthenticationInterceptor));
+
+ when(accountsManager.getByAccountIdentifier(AUTHENTICATED_ACI)).thenReturn(Optional.of(authenticatedAccount));
+ when(accountsManager.getByPhoneNumberIdentifier(AUTHENTICATED_PNI)).thenReturn(Optional.of(authenticatedAccount));
+
+ when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
+ when(accountsManager.getByPhoneNumberIdentifierAsync(AUTHENTICATED_PNI)).thenReturn(CompletableFuture.completedFuture(Optional.of(authenticatedAccount)));
+ }
+
+ @Test
+ void getPreKeyCount() {
+ when(keysManager.getEcCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(1));
+
+ when(keysManager.getPqCount(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(2));
+
+ when(keysManager.getEcCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(3));
+
+ when(keysManager.getPqCount(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID))
+ .thenReturn(CompletableFuture.completedFuture(4));
+
+ assertEquals(GetPreKeyCountResponse.newBuilder()
+ .setAciEcPreKeyCount(1)
+ .setAciKemPreKeyCount(2)
+ .setPniEcPreKeyCount(3)
+ .setPniKemPreKeyCount(4)
+ .build(),
+ keysStub.getPreKeyCount(GetPreKeyCountRequest.newBuilder().build()));
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
+ void setOneTimeEcPreKeys(final org.signal.chat.common.IdentityType identityType) {
+ final List preKeys = new ArrayList<>();
+
+ for (int keyId = 0; keyId < 100; keyId++) {
+ preKeys.add(new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()));
+ }
+
+ when(keysManager.storeEcOneTimePreKeys(any(), anyLong(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ //noinspection ResultOfMethodCallIgnored
+ keysStub.setOneTimeEcPreKeys(SetOneTimeEcPreKeysRequest.newBuilder()
+ .setIdentityType(identityType)
+ .addAllPreKeys(preKeys.stream()
+ .map(preKey -> EcPreKey.newBuilder()
+ .setKeyId(preKey.keyId())
+ .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
+ .build())
+ .toList())
+ .build());
+
+ final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> AUTHENTICATED_ACI;
+ case PNI -> AUTHENTICATED_PNI;
+ };
+
+ verify(keysManager).storeEcOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys);
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void setOneTimeEcPreKeysWithError(final SetOneTimeEcPreKeysRequest request) {
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeEcPreKeys(request));
+
+ assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ }
+
+ private static Stream setOneTimeEcPreKeysWithError() {
+ return Stream.of(
+ // Missing identity type
+ Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
+ .addPreKeys(EcPreKey.newBuilder()
+ .setKeyId(1)
+ .setPublicKey(ByteString.copyFrom(Curve.generateKeyPair().getPublicKey().serialize()))
+ .build())
+ .build()),
+
+ // Invalid public key
+ Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .addPreKeys(EcPreKey.newBuilder()
+ .setKeyId(1)
+ .setPublicKey(ByteString.empty())
+ .build())
+ .build()),
+
+ // No keys
+ Arguments.of(SetOneTimeEcPreKeysRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .build())
+ );
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
+ void setOneTimeKemSignedPreKeys(final org.signal.chat.common.IdentityType identityType) {
+ final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> ACI_IDENTITY_KEY_PAIR;
+ case PNI -> PNI_IDENTITY_KEY_PAIR;
+ };
+
+ final List preKeys = new ArrayList<>();
+
+ for (int keyId = 0; keyId < 100; keyId++) {
+ preKeys.add(KeysHelper.signedKEMPreKey(keyId, identityKeyPair));
+ }
+
+ when(keysManager.storeKemOneTimePreKeys(any(), anyLong(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ //noinspection ResultOfMethodCallIgnored
+ keysStub.setOneTimeKemSignedPreKeys(
+ SetOneTimeKemSignedPreKeysRequest.newBuilder()
+ .setIdentityType(identityType)
+ .addAllPreKeys(preKeys.stream()
+ .map(preKey -> KemSignedPreKey.newBuilder()
+ .setKeyId(preKey.keyId())
+ .setPublicKey(ByteString.copyFrom(preKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(preKey.signature()))
+ .build())
+ .toList())
+ .build());
+
+ final UUID expectedIdentifier = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> AUTHENTICATED_ACI;
+ case PNI -> AUTHENTICATED_PNI;
+ };
+
+ verify(keysManager).storeKemOneTimePreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, preKeys);
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void setOneTimeKemSignedPreKeysWithError(final SetOneTimeKemSignedPreKeysRequest request) {
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.setOneTimeKemSignedPreKeys(request));
+
+ assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ }
+
+ private static Stream setOneTimeKemSignedPreKeysWithError() {
+ final KEMSignedPreKey signedPreKey = KeysHelper.signedKEMPreKey(1, ACI_IDENTITY_KEY_PAIR);
+
+ return Stream.of(
+ // Missing identity type
+ Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
+ .addPreKeys(KemSignedPreKey.newBuilder()
+ .setKeyId(1)
+ .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(signedPreKey.signature()))
+ .build())
+ .build()),
+
+ // Invalid public key
+ Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .addPreKeys(KemSignedPreKey.newBuilder()
+ .setKeyId(1)
+ .setPublicKey(ByteString.empty())
+ .setSignature(ByteString.copyFrom(signedPreKey.signature()))
+ .build())
+ .build()),
+
+ // Invalid signature
+ Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .addPreKeys(KemSignedPreKey.newBuilder()
+ .setKeyId(1)
+ .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.empty())
+ .build())
+ .build()),
+
+ // No keys
+ Arguments.of(SetOneTimeKemSignedPreKeysRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .build())
+ );
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
+ void setSignedPreKey(final org.signal.chat.common.IdentityType identityType) {
+ when(accountsManager.updateDeviceAsync(any(), anyLong(), any())).thenAnswer(invocation -> {
+ final Account account = invocation.getArgument(0);
+ final long deviceId = invocation.getArgument(1);
+ final Consumer deviceUpdater = invocation.getArgument(2);
+
+ account.getDevice(deviceId).ifPresent(deviceUpdater);
+
+ return CompletableFuture.completedFuture(account);
+ });
+
+ when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
+
+ final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> ACI_IDENTITY_KEY_PAIR;
+ case PNI -> PNI_IDENTITY_KEY_PAIR;
+ };
+
+ final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, identityKeyPair);
+
+ //noinspection ResultOfMethodCallIgnored
+ keysStub.setEcSignedPreKey(SetEcSignedPreKeyRequest.newBuilder()
+ .setIdentityType(identityType)
+ .setSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(signedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(signedPreKey.signature()))
+ .build())
+ .build());
+
+ switch (identityType) {
+ case IDENTITY_TYPE_ACI -> {
+ verify(authenticatedDevice).setSignedPreKey(signedPreKey);
+ verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
+ }
+
+ case IDENTITY_TYPE_PNI -> {
+ verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey);
+ verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, Map.of(AUTHENTICATED_DEVICE_ID, signedPreKey));
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void setSignedPreKeyWithError(final SetEcSignedPreKeyRequest request) {
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.setEcSignedPreKey(request));
+
+ assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ }
+
+ private static Stream setSignedPreKeyWithError() {
+ final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(17, ACI_IDENTITY_KEY_PAIR);
+
+ return Stream.of(
+ // Missing identity type
+ Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
+ .setSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(signedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(signedPreKey.signature()))
+ .build())
+ .build()),
+
+ // Invalid public key
+ Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(signedPreKey.keyId())
+ .setPublicKey(ByteString.empty())
+ .setSignature(ByteString.copyFrom(signedPreKey.signature()))
+ .build())
+ .build()),
+
+ // Invalid signature
+ Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(signedPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(signedPreKey.serializedPublicKey()))
+ .setSignature(ByteString.empty())
+ .build())
+ .build()),
+
+ // Missing key
+ Arguments.of(SetEcSignedPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .build())
+ );
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
+ void setLastResortPreKey(final org.signal.chat.common.IdentityType identityType) {
+ when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
+
+ final ECKeyPair identityKeyPair = switch (IdentityType.fromGrpcIdentityType(identityType)) {
+ case ACI -> ACI_IDENTITY_KEY_PAIR;
+ case PNI -> PNI_IDENTITY_KEY_PAIR;
+ };
+
+ final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, identityKeyPair);
+
+ //noinspection ResultOfMethodCallIgnored
+ keysStub.setKemLastResortPreKey(SetKemLastResortPreKeyRequest.newBuilder()
+ .setIdentityType(identityType)
+ .setSignedPreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(lastResortPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
+ .build())
+ .build());
+
+ final UUID expectedIdentifier = switch (identityType) {
+ case IDENTITY_TYPE_ACI -> AUTHENTICATED_ACI;
+ case IDENTITY_TYPE_PNI -> AUTHENTICATED_PNI;
+ case IDENTITY_TYPE_UNSPECIFIED, UNRECOGNIZED -> throw new AssertionError("Bad identity type");
+ };
+
+ verify(keysManager).storePqLastResort(expectedIdentifier, Map.of(AUTHENTICATED_DEVICE_ID, lastResortPreKey));
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void setLastResortPreKeyWithError(final SetKemLastResortPreKeyRequest request) {
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.setKemLastResortPreKey(request));
+
+ assertEquals(Status.INVALID_ARGUMENT.getCode(), exception.getStatus().getCode());
+ }
+
+ private static Stream setLastResortPreKeyWithError() {
+ final KEMSignedPreKey lastResortPreKey = KeysHelper.signedKEMPreKey(17, ACI_IDENTITY_KEY_PAIR);
+
+ return Stream.of(
+ // No identity type
+ Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
+ .setSignedPreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(lastResortPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
+ .build())
+ .build()),
+
+ // Bad public key
+ Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setSignedPreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(lastResortPreKey.keyId())
+ .setPublicKey(ByteString.empty())
+ .setSignature(ByteString.copyFrom(lastResortPreKey.signature()))
+ .build())
+ .build()),
+
+ // Bad signature
+ Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setSignedPreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(lastResortPreKey.keyId())
+ .setPublicKey(ByteString.copyFrom(lastResortPreKey.serializedPublicKey()))
+ .setSignature(ByteString.empty())
+ .build())
+ .build()),
+
+ // Missing key
+ Arguments.of(SetKemLastResortPreKeyRequest.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .build())
+ );
+ }
+
+ @ParameterizedTest
+ @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
+ void getPreKeys(final org.signal.chat.common.IdentityType identityType) {
+ final Account targetAccount = mock(Account.class);
+
+ final ECKeyPair identityKeyPair = Curve.generateKeyPair();
+ final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
+ final UUID identifier = UUID.randomUUID();
+
+ if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
+ when(targetAccount.getUuid()).thenReturn(identifier);
+ when(targetAccount.getIdentityKey()).thenReturn(identityKey);
+ when(accountsManager.getByAccountIdentifierAsync(identifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+ } else {
+ when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
+ when(targetAccount.getPhoneNumberIdentifier()).thenReturn(identifier);
+ when(targetAccount.getPhoneNumberIdentityKey()).thenReturn(identityKey);
+ when(accountsManager.getByPhoneNumberIdentifierAsync(identifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+ }
+
+ final Map ecOneTimePreKeys = new HashMap<>();
+ final Map kemPreKeys = new HashMap<>();
+ final Map ecSignedPreKeys = new HashMap<>();
+
+ final Map devices = new HashMap<>();
+
+ for (final long deviceId : List.of(1, 2)) {
+ ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey()));
+ kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair));
+ ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair));
+
+ final Device device = mock(Device.class);
+ when(device.getId()).thenReturn(deviceId);
+ when(device.isEnabled()).thenReturn(true);
+
+ if (identityType == org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI) {
+ when(device.getSignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
+ } else {
+ when(device.getPhoneNumberIdentitySignedPreKey()).thenReturn(ecSignedPreKeys.get(deviceId));
+ }
+
+ devices.put(deviceId, device);
+ when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
+ }
+
+ when(targetAccount.getDevices()).thenReturn(new ArrayList<>(devices.values()));
+
+ ecOneTimePreKeys.forEach((deviceId, preKey) -> when(keysManager.takeEC(identifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
+
+ kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));
+
+ {
+ final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(identityType)
+ .setUuid(UUIDUtil.toByteString(identifier))
+ .build())
+ .setDeviceId(1)
+ .build());
+
+ final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
+ .setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
+ .putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
+ .setEcSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(ecSignedPreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
+ .build())
+ .setEcOneTimePreKey(EcPreKey.newBuilder()
+ .setKeyId(ecOneTimePreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
+ .build())
+ .setKemOneTimePreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(kemPreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
+ .build())
+ .build())
+ .build();
+
+ assertEquals(expectedResponse, response);
+ }
+
+ when(keysManager.takeEC(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+ when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ {
+ final GetPreKeysResponse response = keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(identityType)
+ .setUuid(UUIDUtil.toByteString(identifier))
+ .build())
+ .build());
+
+ final GetPreKeysResponse expectedResponse = GetPreKeysResponse.newBuilder()
+ .setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
+ .putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
+ .setEcSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(ecSignedPreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
+ .build())
+ .setEcOneTimePreKey(EcPreKey.newBuilder()
+ .setKeyId(ecOneTimePreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
+ .build())
+ .setKemOneTimePreKey(KemSignedPreKey.newBuilder()
+ .setKeyId(kemPreKeys.get(1L).keyId())
+ .setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
+ .build())
+ .build())
+ .putPreKeys(2, GetPreKeysResponse.PreKeyBundle.newBuilder()
+ .setEcSignedPreKey(EcSignedPreKey.newBuilder()
+ .setKeyId(ecSignedPreKeys.get(2L).keyId())
+ .setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(2L).serializedPublicKey()))
+ .setSignature(ByteString.copyFrom(ecSignedPreKeys.get(2L).signature()))
+ .build())
+ .build())
+ .build();
+
+ assertEquals(expectedResponse, response);
+ }
+ }
+
+ @Test
+ void getPreKeysAccountNotFound() {
+ when(accountsManager.getByAccountIdentifierAsync(any()))
+ .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
+ .build())
+ .build()));
+
+ assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ }
+
+ @ParameterizedTest
+ @ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
+ void getPreKeysDeviceNotFound(final long deviceId) {
+ final UUID accountIdentifier = UUID.randomUUID();
+
+ final Account targetAccount = mock(Account.class);
+ when(targetAccount.getUuid()).thenReturn(accountIdentifier);
+ when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
+ when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
+ when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
+
+ when(accountsManager.getByAccountIdentifierAsync(accountIdentifier))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(accountIdentifier))
+ .build())
+ .setDeviceId(deviceId)
+ .build()));
+
+ assertEquals(Status.Code.NOT_FOUND, exception.getStatus().getCode());
+ }
+
+ @Test
+ void getPreKeysRateLimited() {
+ final Account targetAccount = mock(Account.class);
+ when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
+ when(targetAccount.getIdentityKey()).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
+ when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
+ when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
+
+ when(accountsManager.getByAccountIdentifierAsync(any()))
+ .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
+
+ final Duration retryAfterDuration = Duration.ofMinutes(7);
+
+ when(preKeysRateLimiter.validateReactive(anyString()))
+ .thenReturn(Mono.error(new RateLimitExceededException(retryAfterDuration, false)));
+
+ @SuppressWarnings("ResultOfMethodCallIgnored") final StatusRuntimeException exception =
+ assertThrows(StatusRuntimeException.class, () -> keysStub.getPreKeys(GetPreKeysRequest.newBuilder()
+ .setTargetIdentifier(ServiceIdentifier.newBuilder()
+ .setIdentityType(org.signal.chat.common.IdentityType.IDENTITY_TYPE_ACI)
+ .setUuid(UUIDUtil.toByteString(UUID.randomUUID()))
+ .build())
+ .build()));
+
+ assertEquals(Status.Code.RESOURCE_EXHAUSTED, exception.getStatus().getCode());
+ assertNotNull(exception.getTrailers());
+ assertEquals(retryAfterDuration, exception.getTrailers().get(RateLimitUtil.RETRY_AFTER_DURATION_KEY));
+ }
+}