diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index e9689fdc5..64a2a6560 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -554,8 +554,6 @@ public class WhisperServerService extends Application authenticatedAccount, @NotNull @Valid final KeyTransparencySearchRequest request) { @@ -104,19 +101,17 @@ public class KeyTransparencyController { .build() )); - return keyTransparencyServiceClient.search( + return new KeyTransparencySearchResponse( + keyTransparencyServiceClient.search( ByteString.copyFrom(request.aci().toCompactByteArray()), ByteString.copyFrom(request.aciIdentityKey().serialize()), request.usernameHash().map(ByteString::copyFrom), maybeE164SearchRequest, request.lastTreeHeadSize(), - request.distinguishedTreeHeadSize(), - KEY_TRANSPARENCY_RPC_TIMEOUT) - .thenApply(KeyTransparencySearchResponse::new).join(); - } catch (final CancellationException exception) { - LOGGER.error("Unexpected cancellation from key transparency service", exception); - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception); - } catch (final CompletionException exception) { + request.distinguishedTreeHeadSize()) + .toByteArray()); + } catch (final StatusRuntimeException exception) { + LOGGER.error("Unexpected error calling key transparency service", exception); handleKeyTransparencyServiceError(exception); } // This is unreachable @@ -140,6 +135,7 @@ public class KeyTransparencyController { @Path("/monitor") @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP) @Produces(MediaType.APPLICATION_JSON) + @ManagedAsync public KeyTransparencyMonitorResponse monitor( @Auth final Optional authenticatedAccount, @NotNull @Valid final KeyTransparencyMonitorRequest request) { @@ -173,13 +169,10 @@ public class KeyTransparencyController { usernameHashMonitorRequest, e164MonitorRequest, request.lastNonDistinguishedTreeHeadSize(), - request.lastDistinguishedTreeHeadSize(), - KEY_TRANSPARENCY_RPC_TIMEOUT).join()); - - } catch (final CancellationException exception) { - LOGGER.error("Unexpected cancellation from key transparency service", exception); - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception); - } catch (final CompletionException exception) { + request.lastDistinguishedTreeHeadSize()) + .toByteArray()); + } catch (final StatusRuntimeException exception) { + LOGGER.error("Unexpected error calling key transparency service", exception); handleKeyTransparencyServiceError(exception); } // This is unreachable @@ -202,6 +195,7 @@ public class KeyTransparencyController { @Path("/distinguished") @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP) @Produces(MediaType.APPLICATION_JSON) + @ManagedAsync public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey( @Auth final Optional authenticatedAccount, @@ -212,34 +206,26 @@ public class KeyTransparencyController { requireNotAuthenticated(authenticatedAccount); try { - return keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize, KEY_TRANSPARENCY_RPC_TIMEOUT) - .thenApply(KeyTransparencyDistinguishedKeyResponse::new) - .join(); - } catch (final CancellationException exception) { - LOGGER.error("Unexpected cancellation from key transparency service", exception); - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception); - } catch (final CompletionException exception) { + return new KeyTransparencyDistinguishedKeyResponse( + keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize) + .toByteArray()); + } catch (final StatusRuntimeException exception) { + LOGGER.error("Unexpected error calling key transparency service", exception); handleKeyTransparencyServiceError(exception); } // This is unreachable return null; } - private void handleKeyTransparencyServiceError(final CompletionException exception) { - final Throwable unwrapped = ExceptionUtils.unwrap(exception); - - if (unwrapped instanceof StatusRuntimeException e) { - final Status.Code code = e.getStatus().getCode(); - final String description = e.getStatus().getDescription(); - switch (code) { - case NOT_FOUND -> throw new NotFoundException(description); - case PERMISSION_DENIED -> throw new ForbiddenException(description); - case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422); - default -> throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped); - } + private void handleKeyTransparencyServiceError(final StatusRuntimeException exception) { + final Status.Code code = exception.getStatus().getCode(); + final String description = exception.getStatus().getDescription(); + switch (code) { + case NOT_FOUND -> throw new NotFoundException(description); + case PERMISSION_DENIED -> throw new ForbiddenException(description); + case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422); + default -> throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, exception); } - LOGGER.error("Unexpected key transparency service failure", unwrapped); - throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped); } private void requireNotAuthenticated(final Optional authenticatedAccount) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcService.java new file mode 100644 index 000000000..5c1bf7c69 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcService.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import org.signal.keytransparency.client.AciMonitorRequest; +import org.signal.keytransparency.client.ConsistencyParameters; +import org.signal.keytransparency.client.DistinguishedRequest; +import org.signal.keytransparency.client.DistinguishedResponse; +import org.signal.keytransparency.client.E164MonitorRequest; +import org.signal.keytransparency.client.E164SearchRequest; +import org.signal.keytransparency.client.MonitorRequest; +import org.signal.keytransparency.client.MonitorResponse; +import org.signal.keytransparency.client.SearchRequest; +import org.signal.keytransparency.client.SearchResponse; +import org.signal.keytransparency.client.SimpleKeyTransparencyQueryServiceGrpc; +import org.signal.keytransparency.client.UsernameHashMonitorRequest; +import org.whispersystems.textsecuregcm.controllers.AccountController; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; +import org.whispersystems.textsecuregcm.limits.RateLimiters; + +public class KeyTransparencyGrpcService extends + SimpleKeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceImplBase { + @VisibleForTesting + static final int COMMITMENT_INDEX_LENGTH = 32; + private final RateLimiters rateLimiters; + private final KeyTransparencyServiceClient client; + + public KeyTransparencyGrpcService(final RateLimiters rateLimiters, + final KeyTransparencyServiceClient client) { + this.rateLimiters = rateLimiters; + this.client = client; + } + + @Override + public SearchResponse search(final SearchRequest request) throws RateLimitExceededException { + rateLimiters.getKeyTransparencySearchLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress()); + return client.search(validateSearchRequest(request)); + } + + @Override + public MonitorResponse monitor(final MonitorRequest request) throws RateLimitExceededException { + rateLimiters.getKeyTransparencyMonitorLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress()); + return client.monitor(validateMonitorRequest(request)); + } + + @Override + public DistinguishedResponse distinguished(final DistinguishedRequest request) throws RateLimitExceededException { + rateLimiters.getKeyTransparencyDistinguishedLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress()); + // A client's very first distinguished request will not have a "last" parameter + if (request.hasLast() && request.getLast() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException(); + } + return client.distinguished(request); + } + + private SearchRequest validateSearchRequest(final SearchRequest request) { + if (request.hasE164SearchRequest()) { + final E164SearchRequest e164SearchRequest = request.getE164SearchRequest(); + if (e164SearchRequest.getUnidentifiedAccessKey().isEmpty() != e164SearchRequest.getE164().isEmpty()) { + throw Status.INVALID_ARGUMENT.withDescription("Unidentified access key and E164 must be provided together or not at all").asRuntimeException(); + } + } + + if (!request.getConsistency().hasDistinguished()) { + throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished tree head size").asRuntimeException(); + } + + validateConsistencyParameters(request.getConsistency()); + return request; + } + + private MonitorRequest validateMonitorRequest(final MonitorRequest request) { + final AciMonitorRequest aciMonitorRequest = request.getAci(); + + try { + AciServiceIdentifier.fromBytes(aciMonitorRequest.getAci().toByteArray()); + } catch (IllegalArgumentException e) { + throw Status.INVALID_ARGUMENT.withDescription("Invalid ACI").asRuntimeException(); + } + if (aciMonitorRequest.getEntryPosition() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("Aci entry position must be positive").asRuntimeException(); + } + if (aciMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) { + throw Status.INVALID_ARGUMENT.withDescription("Aci commitment index must be 32 bytes").asRuntimeException(); + } + + if (request.hasUsernameHash()) { + final UsernameHashMonitorRequest usernameHashMonitorRequest = request.getUsernameHash(); + if (usernameHashMonitorRequest.getUsernameHash().isEmpty()) { + throw Status.INVALID_ARGUMENT.withDescription("Username hash cannot be empty").asRuntimeException(); + } + if (usernameHashMonitorRequest.getUsernameHash().size() != AccountController.USERNAME_HASH_LENGTH) { + throw Status.INVALID_ARGUMENT.withDescription("Invalid username hash length").asRuntimeException(); + } + if (usernameHashMonitorRequest.getEntryPosition() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("Username hash entry position must be positive").asRuntimeException(); + } + if (usernameHashMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) { + throw Status.INVALID_ARGUMENT.withDescription("Username hash commitment index must be 32 bytes").asRuntimeException(); + } + } + + if (request.hasE164()) { + final E164MonitorRequest e164MonitorRequest = request.getE164(); + if (e164MonitorRequest.getE164().isEmpty()) { + throw Status.INVALID_ARGUMENT.withDescription("E164 cannot be empty").asRuntimeException(); + } + if (e164MonitorRequest.getEntryPosition() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("E164 entry position must be positive").asRuntimeException(); + } + if (e164MonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) { + throw Status.INVALID_ARGUMENT.withDescription("E164 commitment index must be 32 bytes").asRuntimeException(); + } + } + + if (!request.getConsistency().hasDistinguished() || !request.getConsistency().hasLast()) { + throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished and last tree head sizes").asRuntimeException(); + } + + validateConsistencyParameters(request.getConsistency()); + return request; + } + + private static void validateConsistencyParameters(final ConsistencyParameters consistency) { + if (consistency.getDistinguished() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("Distinguished tree head size must be positive").asRuntimeException(); + } + + if (consistency.hasLast() && consistency.getLast() <= 0) { + throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException(); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java index 972964ad7..e455fb4ae 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java @@ -1,6 +1,5 @@ package org.whispersystems.textsecuregcm.keytransparency; -import com.google.protobuf.AbstractMessageLite; import com.google.protobuf.ByteString; import io.dropwizard.lifecycle.Managed; import io.grpc.ChannelCredentials; @@ -20,44 +19,43 @@ import java.time.Duration; import java.time.Instant; import java.util.Collection; import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import org.signal.keytransparency.client.AciMonitorRequest; import org.signal.keytransparency.client.ConsistencyParameters; import org.signal.keytransparency.client.DistinguishedRequest; +import org.signal.keytransparency.client.DistinguishedResponse; import org.signal.keytransparency.client.E164MonitorRequest; import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc; import org.signal.keytransparency.client.MonitorRequest; +import org.signal.keytransparency.client.MonitorResponse; import org.signal.keytransparency.client.SearchRequest; +import org.signal.keytransparency.client.SearchResponse; import org.signal.keytransparency.client.UsernameHashMonitorRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.util.CompletableFutureUtil; public class KeyTransparencyServiceClient implements Managed { private static final String DAYS_UNTIL_CLIENT_CERTIFICATE_EXPIRATION_GAUGE_NAME = MetricsUtil.name(KeyTransparencyServiceClient.class, "daysUntilClientCertificateExpiration"); + private static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15); private static final Logger logger = LoggerFactory.getLogger(KeyTransparencyServiceClient.class); - private final Executor callbackExecutor; private final String host; private final int port; private final ChannelCredentials tlsChannelCredentials; private ManagedChannel channel; - private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceFutureStub stub; + private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub stub; public KeyTransparencyServiceClient( final String host, final int port, final String tlsCertificate, final String clientCertificate, - final String clientPrivateKey, - final Executor callbackExecutor + final String clientPrivateKey ) throws IOException { this.host = host; this.port = port; @@ -76,7 +74,6 @@ public class KeyTransparencyServiceClient implements Managed { configureClientCertificateMetrics(clientCertificate); } - this.callbackExecutor = callbackExecutor; } private void configureClientCertificateMetrics(String clientCertificate) { @@ -113,14 +110,13 @@ public class KeyTransparencyServiceClient implements Managed { } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public CompletableFuture search( + public SearchResponse search( final ByteString aci, final ByteString aciIdentityKey, final Optional usernameHash, final Optional e164SearchRequest, final Optional lastTreeHeadSize, - final long distinguishedTreeHeadSize, - final Duration timeout) { + final long distinguishedTreeHeadSize) { final SearchRequest.Builder searchRequestBuilder = SearchRequest.newBuilder() .setAci(aci) .setAciIdentityKey(aciIdentityKey); @@ -133,19 +129,20 @@ public class KeyTransparencyServiceClient implements Managed { lastTreeHeadSize.ifPresent(consistency::setLast); searchRequestBuilder.setConsistency(consistency.build()); + return search(searchRequestBuilder.build()); + } - return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) - .search(searchRequestBuilder.build()), callbackExecutor) - .thenApply(AbstractMessageLite::toByteArray); + public SearchResponse search(final SearchRequest request) { + return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT)) + .search(request); } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public CompletableFuture monitor(final AciMonitorRequest aciMonitorRequest, + public MonitorResponse monitor(final AciMonitorRequest aciMonitorRequest, final Optional usernameHashMonitorRequest, final Optional e164MonitorRequest, final long lastTreeHeadSize, - final long distinguishedTreeHeadSize, - final Duration timeout) { + final long distinguishedTreeHeadSize) { final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder() .setAci(aciMonitorRequest) .setConsistency(ConsistencyParameters.newBuilder() @@ -155,20 +152,26 @@ public class KeyTransparencyServiceClient implements Managed { usernameHashMonitorRequest.ifPresent(monitorRequestBuilder::setUsernameHash); e164MonitorRequest.ifPresent(monitorRequestBuilder::setE164); - - return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) - .monitor(monitorRequestBuilder.build()), callbackExecutor) - .thenApply(AbstractMessageLite::toByteArray); + return monitor(monitorRequestBuilder.build()); } + public MonitorResponse monitor(final MonitorRequest request) { + return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT)) + .monitor(request); + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public CompletableFuture getDistinguishedKey(final Optional lastTreeHeadSize, final Duration timeout) { + public DistinguishedResponse getDistinguishedKey(final Optional lastTreeHeadSize) { final DistinguishedRequest request = lastTreeHeadSize.map( last -> DistinguishedRequest.newBuilder().setLast(last).build()) .orElseGet(DistinguishedRequest::getDefaultInstance); - return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)).distinguished(request), - callbackExecutor) - .thenApply(AbstractMessageLite::toByteArray); + return distinguished(request); + } + + public DistinguishedResponse distinguished(final DistinguishedRequest request) { + return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT)) + .distinguished(request); } private static Deadline toDeadline(final Duration timeout) { @@ -180,7 +183,7 @@ public class KeyTransparencyServiceClient implements Managed { channel = Grpc.newChannelBuilderForAddress(host, port, tlsChannelCredentials) .idleTimeout(1, TimeUnit.MINUTES) .build(); - stub = KeyTransparencyQueryServiceGrpc.newFutureStub(channel); + stub = KeyTransparencyQueryServiceGrpc.newBlockingStub(channel); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index e42584913..6f4526bff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -206,4 +206,16 @@ public class RateLimiters extends BaseRateLimiters { public RateLimiter getWaitForTransferArchiveLimiter() { return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE); } + + public RateLimiter getKeyTransparencySearchLimiter() { + return forDescriptor(For.KEY_TRANSPARENCY_SEARCH_PER_IP); + } + + public RateLimiter getKeyTransparencyDistinguishedLimiter() { + return forDescriptor(For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP); + } + + public RateLimiter getKeyTransparencyMonitorLimiter() { + return forDescriptor(For.KEY_TRANSPARENCY_MONITOR_PER_IP); + } } diff --git a/service/src/main/proto/KeyTransparencyService.proto b/service/src/main/proto/KeyTransparencyService.proto index 6cf7c14a8..ad0a46194 100644 --- a/service/src/main/proto/KeyTransparencyService.proto +++ b/service/src/main/proto/KeyTransparencyService.proto @@ -10,6 +10,8 @@ option java_package = "org.signal.keytransparency.client"; package kt_query; +import "org/signal/chat/require.proto"; + /** * An external-facing, read-only key transparency service used by Signal's chat server * to look up and monitor identifiers. @@ -19,8 +21,13 @@ package kt_query; * - A username hash which also maps to an ACI * Separately, the log also stores and periodically updates a fixed value known as the `distinguished` key. * Clients use the verified tree head from looking up this key for future calls to the Search and Monitor endpoints. + * + * Note that this service definition is used in two different contexts: + * 1. Implementing the endpoints with rate-limiting and request validation + * 2. Using the generated client stub to forward requests to the remote key transparency service */ service KeyTransparencyQueryService { + option (org.signal.chat.require.auth) = AUTH_ONLY_ANONYMOUS; /** * An endpoint used by clients to retrieve the most recent distinguished tree * head, which should be used to derive consistency parameters for @@ -44,15 +51,15 @@ message SearchRequest { /** * The ACI to look up in the log. */ - bytes aci = 1; + bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16]; /** * The ACI identity key that the client thinks the ACI maps to in the log. */ - bytes aci_identity_key = 2; + bytes aci_identity_key = 2 [(org.signal.chat.require.nonEmpty) = true]; /** * The username hash to look up in the log. */ - optional bytes username_hash = 3; + optional bytes username_hash = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32]; /** * The E164 to look up in the log along with associated data. */ @@ -60,7 +67,7 @@ message SearchRequest { /** * The tree head size(s) to prove consistency against. */ - ConsistencyParameters consistency = 5; + ConsistencyParameters consistency = 5 [(org.signal.chat.require.present) = true]; } /** @@ -70,7 +77,7 @@ message E164SearchRequest { /** * The E164 that the client wishes to look up in the transparency log. */ - string e164 = 1; + optional string e164 = 1 [(org.signal.chat.require.e164) = true]; /** * The unidentified access key of the account associated with the provided E164. */ @@ -328,28 +335,28 @@ message PrefixSearchResult { } message MonitorRequest { - AciMonitorRequest aci = 1; + AciMonitorRequest aci = 1 [(org.signal.chat.require.present) = true]; optional UsernameHashMonitorRequest username_hash = 2; optional E164MonitorRequest e164 = 3; - ConsistencyParameters consistency = 4; + ConsistencyParameters consistency = 4 [(org.signal.chat.require.present) = true]; } message AciMonitorRequest { - bytes aci = 1; + bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16]; uint64 entry_position = 2; - bytes commitment_index = 3; + bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 32]; } message UsernameHashMonitorRequest { - bytes username_hash = 1; + bytes username_hash = 1 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32]; uint64 entry_position = 2; - bytes commitment_index = 3; + bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32]; } message E164MonitorRequest { - string e164 = 1; + optional string e164 = 1 [(org.signal.chat.require.e164) = true]; uint64 entry_position = 2; - bytes commitment_index = 3; + bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32]; } message MonitorProof { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java index 7dc468b85..37cde928f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java @@ -35,12 +35,8 @@ import jakarta.ws.rs.client.WebTarget; import jakarta.ws.rs.core.Response; import java.io.UncheckedIOException; import java.time.Duration; -import java.util.Collections; -import java.util.List; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.stream.Stream; import javax.annotation.Nullable; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; @@ -54,8 +50,10 @@ import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.signal.keytransparency.client.CondensedTreeSearchResponse; +import org.signal.keytransparency.client.DistinguishedResponse; import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.FullTreeHead; +import org.signal.keytransparency.client.MonitorResponse; import org.signal.keytransparency.client.SearchProof; import org.signal.keytransparency.client.SearchResponse; import org.signal.keytransparency.client.UpdateValue; @@ -81,16 +79,16 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; @ExtendWith(DropwizardExtensionsSupport.class) public class KeyTransparencyControllerTest { - private static final String NUMBER = PhoneNumberUtil.getInstance().format( + public static final String NUMBER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); - private static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID()); - private static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20); + public static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID()); + public static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20); private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER = new TestRemoteAddressFilterProvider("127.0.0.1"); - private static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + public static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); private static final byte[] COMMITMENT_INDEX = new byte[32]; - private static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16]; + public static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16]; private final KeyTransparencyServiceClient keyTransparencyServiceClient = mock(KeyTransparencyServiceClient.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiter searchRatelimiter = mock(RateLimiter.class); @@ -141,8 +139,8 @@ public class KeyTransparencyControllerTest { e164.ifPresent(ignored -> searchResponseBuilder.setE164(CondensedTreeSearchResponse.getDefaultInstance())); usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(CondensedTreeSearchResponse.getDefaultInstance())); - when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any())) - .thenReturn(CompletableFuture.completedFuture(searchResponseBuilder.build().toByteArray())); + when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong())) + .thenReturn(searchResponseBuilder.build()); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") @@ -167,8 +165,7 @@ public class KeyTransparencyControllerTest { ArgumentCaptor> e164Argument = ArgumentCaptor.forClass(Optional.class); verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(), - usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L), - eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); + usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L)); assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray()); assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray()); @@ -218,8 +215,8 @@ public class KeyTransparencyControllerTest { @ParameterizedTest @MethodSource void searchGrpcErrors(final Status grpcStatus, final int httpStatus) { - when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any())) - .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); + when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong())) + .thenThrow(new StatusRuntimeException(grpcStatus)); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") @@ -228,7 +225,7 @@ public class KeyTransparencyControllerTest { Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(), ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), 4L))))) { assertEquals(httpStatus, response.getStatus()); - verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong(), any()); + verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong()); } } @@ -295,8 +292,8 @@ public class KeyTransparencyControllerTest { @Test void monitorSuccess() { - when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any())) - .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); + when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong())) + .thenReturn(MonitorResponse.getDefaultInstance()); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/monitor") @@ -314,7 +311,7 @@ public class KeyTransparencyControllerTest { assertNotNull(keyTransparencyMonitorResponse.serializedResponse()); verify(keyTransparencyServiceClient, times(1)).monitor( - any(), any(), any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); + any(), any(), any(), eq(3L), eq(4L)); } } @@ -337,8 +334,8 @@ public class KeyTransparencyControllerTest { @ParameterizedTest @MethodSource void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) { - when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any())) - .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); + when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong())) + .thenThrow(new StatusRuntimeException(grpcStatus)); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/monitor") @@ -349,7 +346,7 @@ public class KeyTransparencyControllerTest { new KeyTransparencyMonitorRequest.AciMonitor(ACI, 3, COMMITMENT_INDEX), Optional.empty(), Optional.empty(), 3L, 4L))))) { assertEquals(httpStatus, response.getStatus()); - verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong(), any()); + verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong()); } } @@ -500,8 +497,8 @@ public class KeyTransparencyControllerTest { @ParameterizedTest @CsvSource(", 1") void distinguishedSuccess(@Nullable Long lastTreeHeadSize) { - when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) - .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); + when(keyTransparencyServiceClient.getDistinguishedKey(any())) + .thenReturn(DistinguishedResponse.getDefaultInstance()); WebTarget webTarget = resources.getJerseyTest() .target("/v1/key-transparency/distinguished"); @@ -518,8 +515,7 @@ public class KeyTransparencyControllerTest { assertNotNull(distinguishedKeyResponse.serializedResponse()); verify(keyTransparencyServiceClient, times(1)) - .getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)), - eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); + .getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize))); } } @@ -538,15 +534,15 @@ public class KeyTransparencyControllerTest { @ParameterizedTest @MethodSource void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) { - when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) - .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); + when(keyTransparencyServiceClient.getDistinguishedKey(any())) + .thenThrow(new StatusRuntimeException(grpcStatus)); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/distinguished") .request(); try (Response response = request.get()) { assertEquals(httpStatus, response.getStatus()); - verify(keyTransparencyServiceClient).getDistinguishedKey(any(), any()); + verify(keyTransparencyServiceClient).getDistinguishedKey(any()); } } @@ -561,8 +557,8 @@ public class KeyTransparencyControllerTest { @Test void distinguishedInvalidRequest() { - when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) - .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); + when(keyTransparencyServiceClient.getDistinguishedKey(any())) + .thenReturn(DistinguishedResponse.getDefaultInstance()); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/distinguished") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcServiceTest.java new file mode 100644 index 000000000..241b26f3a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/KeyTransparencyGrpcServiceTest.java @@ -0,0 +1,305 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import com.google.protobuf.ByteString; +import io.grpc.Channel; +import io.grpc.Status; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.signal.keytransparency.client.AciMonitorRequest; +import org.signal.keytransparency.client.ConsistencyParameters; +import org.signal.keytransparency.client.DistinguishedRequest; +import org.signal.keytransparency.client.DistinguishedResponse; +import org.signal.keytransparency.client.E164MonitorRequest; +import org.signal.keytransparency.client.E164SearchRequest; +import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc; +import org.signal.keytransparency.client.MonitorRequest; +import org.signal.keytransparency.client.MonitorResponse; +import org.signal.keytransparency.client.SearchRequest; +import org.signal.keytransparency.client.SearchResponse; +import org.signal.keytransparency.client.UsernameHashMonitorRequest; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.Optional; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI_IDENTITY_KEY; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.NUMBER; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.UNIDENTIFIED_ACCESS_KEY; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.USERNAME_HASH; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded; +import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; +import static org.whispersystems.textsecuregcm.grpc.KeyTransparencyGrpcService.COMMITMENT_INDEX_LENGTH; + +@SuppressWarnings("OptionalUsedAsFieldOrParameterType") +public class KeyTransparencyGrpcServiceTest extends SimpleBaseGrpcTest{ + @Mock + private KeyTransparencyServiceClient keyTransparencyServiceClient; + @Mock + private RateLimiter rateLimiter; + + @Override + protected KeyTransparencyGrpcService createServiceBeforeEachTest() { + final RateLimiters rateLimiters = mock(RateLimiters.class); + when(rateLimiters.getKeyTransparencySearchLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getKeyTransparencyDistinguishedLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getKeyTransparencyMonitorLimiter()).thenReturn(rateLimiter); + + return new KeyTransparencyGrpcService(rateLimiters, keyTransparencyServiceClient); + } + + @Override + protected KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub createStub(final Channel channel) { + return KeyTransparencyQueryServiceGrpc.newBlockingStub(channel); + } + + @Test + void searchSuccess() throws RateLimitExceededException { + when(keyTransparencyServiceClient.search(any())).thenReturn(SearchResponse.getDefaultInstance()); + Mockito.doNothing().when(rateLimiter).validate(any(String.class)); + final SearchRequest request = SearchRequest.newBuilder() + .setAci(ByteString.copyFrom(ACI.toCompactByteArray())) + .setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize())) + .setConsistency(ConsistencyParameters.newBuilder() + .setDistinguished(10) + .build()) + .build(); + + assertDoesNotThrow(() -> unauthenticatedServiceStub().search(request)); + verify(keyTransparencyServiceClient, times(1)).search(eq(request)); + } + + @ParameterizedTest + @MethodSource + void searchInvalidRequest(final Optional aciServiceIdentifier, + final Optional aciIdentityKey, + final Optional e164, + final Optional unidentifiedAccessKey, + final Optional usernameHash, + final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize) { + + final SearchRequest.Builder requestBuilder = SearchRequest.newBuilder(); + + aciServiceIdentifier.ifPresent(v -> requestBuilder.setAci(ByteString.copyFrom(v))); + aciIdentityKey.ifPresent(v -> requestBuilder.setAciIdentityKey(ByteString.copyFrom(v.serialize()))); + usernameHash.ifPresent(v -> requestBuilder.setUsernameHash(ByteString.copyFrom(v))); + + final E164SearchRequest.Builder e164RequestBuilder = E164SearchRequest.newBuilder(); + + e164.ifPresent(e164RequestBuilder::setE164); + unidentifiedAccessKey.ifPresent(v -> e164RequestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(v))); + requestBuilder.setE164SearchRequest(e164RequestBuilder.build()); + + final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder(); + distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished); + lastTreeHeadSize.ifPresent(consistencyBuilder::setLast); + requestBuilder.setConsistency(consistencyBuilder.build()); + + assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().search(requestBuilder.build())); + verifyNoInteractions(keyTransparencyServiceClient); + } + + private static Stream searchInvalidRequest() { + byte[] aciBytes = ACI.toCompactByteArray(); + return Stream.of( + Arguments.argumentSet("Empty ACI", Optional.empty(), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)), + Arguments.argumentSet("Null ACI identity key", Optional.of(aciBytes), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)), + Arguments.argumentSet("Invalid ACI", Optional.of(new byte[15]), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)), + Arguments.argumentSet("Non-positive consistency.last", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)), + Arguments.argumentSet("consistency.distinguished not provided",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), + Arguments.argumentSet("Non-positive consistency.distinguished",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L)), + Arguments.argumentSet("E164 can't be provided without an unidentified access key", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.of(NUMBER), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)), + Arguments.argumentSet("Unidentified access key can't be provided without E164", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.of(UNIDENTIFIED_ACCESS_KEY), Optional.empty(), Optional.empty(), Optional.of(4L)), + Arguments.argumentSet("Invalid username hash", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.of(new byte[19]), Optional.empty(), Optional.of(4L)) + ); + } + + @Test + void searchRatelimited() throws RateLimitExceededException { + final Duration retryAfterDuration = Duration.ofMinutes(7); + Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class)); + + final SearchRequest request = SearchRequest.newBuilder() + .setAci(ByteString.copyFrom(ACI.toCompactByteArray())) + .setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize())) + .setConsistency(ConsistencyParameters.newBuilder() + .setDistinguished(10) + .build()) + .build(); + assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().search(request)); + verifyNoInteractions(keyTransparencyServiceClient); + } + + @Test + void monitorSuccess() { + when(keyTransparencyServiceClient.monitor(any())).thenReturn(MonitorResponse.getDefaultInstance()); + when(rateLimiter.validateReactive(any(String.class))) + .thenReturn(Mono.empty()); + final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder() + .setAci(ByteString.copyFrom(ACI.toCompactByteArray())) + .setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH])) + .setEntryPosition(10) + .build(); + + final MonitorRequest request = MonitorRequest.newBuilder() + .setAci(aciMonitorRequest) + .setConsistency(ConsistencyParameters.newBuilder() + .setDistinguished(10) + .setLast(10) + .build()) + .build(); + + assertDoesNotThrow(() -> unauthenticatedServiceStub().monitor(request)); + verify(keyTransparencyServiceClient, times(1)).monitor(eq(request)); + } + + @ParameterizedTest + @MethodSource + void monitorInvalidRequest(final Optional aciMonitorRequest, + final Optional e164MonitorRequest, + final Optional usernameHashMonitorRequest, + final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize) { + + final MonitorRequest.Builder requestBuilder = MonitorRequest.newBuilder(); + + aciMonitorRequest.ifPresent(requestBuilder::setAci); + e164MonitorRequest.ifPresent(requestBuilder::setE164); + usernameHashMonitorRequest.ifPresent(requestBuilder::setUsernameHash); + + final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder(); + lastTreeHeadSize.ifPresent(consistencyBuilder::setLast); + distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished); + + requestBuilder.setConsistency(consistencyBuilder.build()); + + assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().monitor(requestBuilder.build())); + } + + private static Stream monitorInvalidRequest() { + final Optional validAciMonitorRequest = Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 10)); + return Stream.of( + Arguments.argumentSet("ACI monitor request can't be unset", Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("ACI can't be empty",Optional.of(AciMonitorRequest.newBuilder().build()), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Empty ACI on ACI monitor request",Optional.of(constructAciMonitorRequest(new byte[0], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid ACI", Optional.of(constructAciMonitorRequest(new byte[15], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid commitment index on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[31], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid entry position on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 0)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("E164 can't be blank", validAciMonitorRequest, Optional.of(constructE164MonitorRequest("", new byte[32], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid commitment index on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[31], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid entry position on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[32], 0)), Optional.empty(), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Username hash can't be empty", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[0], new byte[32], 10)), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid username hash length", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[31], new byte[32], 10)), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid commitment index on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[31], 10)), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("Invalid entry position on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[32], 0)), Optional.of(4L), Optional.of(4L)), + Arguments.argumentSet("consistency.last must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L), + Arguments.argumentSet("consistency.last must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)), + Arguments.argumentSet("consistency.distinguished must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L)), Optional.empty()), + Arguments.argumentSet("consistency.distinguished must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(0L)) + ); + } + + @Test + void monitorRatelimited() throws RateLimitExceededException { + final Duration retryAfterDuration = Duration.ofMinutes(7); + Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class)); + + final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder() + .setAci(ByteString.copyFrom(ACI.toCompactByteArray())) + .setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH])) + .setEntryPosition(10) + .build(); + + final MonitorRequest request = MonitorRequest.newBuilder() + .setAci(aciMonitorRequest) + .setConsistency(ConsistencyParameters.newBuilder() + .setDistinguished(10) + .setLast(10) + .build()) + .build(); + assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().monitor(request)); + verifyNoInteractions(keyTransparencyServiceClient); + } + + @Test + void distinguishedSuccess() { + when(keyTransparencyServiceClient.distinguished(any())).thenReturn(DistinguishedResponse.getDefaultInstance()); + when(rateLimiter.validateReactive(any(String.class))) + .thenReturn(Mono.empty()); + final DistinguishedRequest request = DistinguishedRequest.newBuilder().build(); + + assertDoesNotThrow(() -> unauthenticatedServiceStub().distinguished(request)); + verify(keyTransparencyServiceClient, times(1)).distinguished(eq(request)); + } + + @Test + void distinguishedInvalidRequest() { + final DistinguishedRequest request = DistinguishedRequest.newBuilder() + .setLast(0) + .build(); + + assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().distinguished(request)); + verifyNoInteractions(keyTransparencyServiceClient); + } + + @Test + void distinguishedRatelimited() throws RateLimitExceededException { + final Duration retryAfterDuration = Duration.ofMinutes(7); + Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class)); + + final DistinguishedRequest request = DistinguishedRequest.newBuilder() + .setLast(10) + .build(); + + assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().distinguished(request)); + verifyNoInteractions(keyTransparencyServiceClient); + } + + private static AciMonitorRequest constructAciMonitorRequest(final byte[] aci, final byte[] commitmentIndex, final long entryPosition) { + return AciMonitorRequest.newBuilder() + .setAci(ByteString.copyFrom(aci)) + .setCommitmentIndex(ByteString.copyFrom(commitmentIndex)) + .setEntryPosition(entryPosition) + .build(); + } + + private static E164MonitorRequest constructE164MonitorRequest(final String e164, final byte[] commitmentIndex, final long entryPosition) { + return E164MonitorRequest.newBuilder() + .setE164(e164) + .setCommitmentIndex(ByteString.copyFrom(commitmentIndex)) + .setEntryPosition(entryPosition) + .build(); + } + + private static UsernameHashMonitorRequest constructUsernameHashMonitorRequest(final byte[] usernameHash, final byte[] commitmentIndex, final long entryPosition) { + return UsernameHashMonitorRequest.newBuilder() + .setUsernameHash(ByteString.copyFrom(usernameHash)) + .setCommitmentIndex(ByteString.copyFrom(commitmentIndex)) + .setEntryPosition(entryPosition) + .build(); + } +}