diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java index 880b45ddf..8b4bb4a0e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java @@ -11,34 +11,9 @@ import io.dropwizard.auth.Auth; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; -import org.signal.keytransparency.client.MonitorKey; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorRequest; -import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorResponse; -import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchRequest; -import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchResponse; -import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; -import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; -import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; -import org.whispersystems.websocket.auth.ReadOnly; - -import javax.validation.Valid; -import javax.validation.constraints.NotNull; -import javax.ws.rs.BadRequestException; -import javax.ws.rs.ForbiddenException; -import javax.ws.rs.NotFoundException; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.ServerErrorException; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Duration; @@ -48,6 +23,35 @@ import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import javax.validation.Valid; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Positive; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.ForbiddenException; +import javax.ws.rs.GET; +import javax.ws.rs.NotFoundException; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.ServerErrorException; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.signal.keytransparency.client.MonitorKey; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.entities.KeyTransparencyDistinguishedKeyResponse; +import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorRequest; +import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorResponse; +import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchRequest; +import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchResponse; +import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; +import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/key-transparency") @Tag(name = "KeyTransparency") @@ -76,9 +80,10 @@ public class KeyTransparencyController { """ ) @ApiResponse(responseCode = "200", description = "All search key lookups were successful", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "Invalid request. See response for any available details.") @ApiResponse(responseCode = "403", description = "At least one search key lookup to value mapping was invalid") @ApiResponse(responseCode = "404", description = "At least one search key lookup did not find the key") - @ApiResponse(responseCode = "429", description = "Ratelimited") + @ApiResponse(responseCode = "429", description = "Rate-limited") @ApiResponse(responseCode = "422", description = "Invalid request format") @POST @Path("/search") @@ -145,8 +150,9 @@ public class KeyTransparencyController { """ ) @ApiResponse(responseCode = "200", description = "All search keys exist in the log", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "Invalid request. See response for any available details.") @ApiResponse(responseCode = "404", description = "At least one search key lookup did not find the key") - @ApiResponse(responseCode = "429", description = "Ratelimited") + @ApiResponse(responseCode = "429", description = "Rate-limited") @ApiResponse(responseCode = "422", description = "Invalid request format") @POST @Path("/monitor") @@ -191,6 +197,44 @@ public class KeyTransparencyController { return null; } + @Operation( + summary = "Get the current value of the distinguished key", + description = """ + Enforced unauthenticated endpoint. The response contains the distinguished tree head to prove consistency + against for future calls to `/search` and `/distinguished`. + """ + ) + @ApiResponse(responseCode = "200", description = "The `distinguished` search key exists in the log", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "Invalid request. See response for any available details.") + @ApiResponse(responseCode = "422", description = "Invalid request format") + @ApiResponse(responseCode = "429", description = "Rate-limited") + @GET + @Path("/distinguished") + @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP) + @Produces(MediaType.APPLICATION_JSON) + public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey( + @ReadOnly @Auth final Optional authenticatedAccount, + + @Parameter(description = "The distinguished tree head size returned by a previously verified call") + @QueryParam("lastTreeHeadSize") @Valid final Optional<@Positive Long> lastTreeHeadSize) { + + // Disallow clients from making authenticated requests to this endpoint + requireNotAuthenticated(authenticatedAccount); + + try { + return keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize, KEY_TRANSPARENCY_RPC_TIMEOUT) + .thenApply(KeyTransparencyDistinguishedKeyResponse::new) + .join(); + } catch (final CancellationException exception) { + LOGGER.error("Unexpected cancellation from key transparency service", exception); + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception); + } catch (final CompletionException exception) { + handleKeyTransparencyServiceError(exception); + } + // This is unreachable + return null; + } + private void handleKeyTransparencyServiceError(final CompletionException exception) { final Throwable unwrapped = ExceptionUtils.unwrap(exception); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyDistinguishedKeyResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyDistinguishedKeyResponse.java new file mode 100644 index 000000000..370c25f6c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyDistinguishedKeyResponse.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.NotNull; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; + +public record KeyTransparencyDistinguishedKeyResponse( + @NotNull + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + @Schema(description = "The response for the distinguished tree head encoded in standard un-padded base64") + byte[] distinguishedKeyResponse +) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java index 4931595c8..7f8abbc2f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java @@ -25,6 +25,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import org.signal.keytransparency.client.ConsistencyParameters; +import org.signal.keytransparency.client.DistinguishedRequest; import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc; import org.signal.keytransparency.client.MonitorKey; import org.signal.keytransparency.client.MonitorRequest; @@ -153,6 +154,16 @@ public class KeyTransparencyServiceClient implements Managed { .thenApply(AbstractMessageLite::toByteArray); } + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + public CompletableFuture getDistinguishedKey(final Optional lastTreeHeadSize, final Duration timeout) { + final DistinguishedRequest request = lastTreeHeadSize.map( + last -> DistinguishedRequest.newBuilder().setLast(last).build()) + .orElseGet(DistinguishedRequest::getDefaultInstance); + return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)).distinguished(request), + callbackExecutor) + .thenApply(AbstractMessageLite::toByteArray); + } + private static Deadline toDeadline(final Duration timeout) { return Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 1d210a253..b7975cbb6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -48,6 +48,8 @@ public class RateLimiters extends BaseRateLimiters { CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))), EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))), + KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, + new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), diff --git a/service/src/main/proto/KeyTransparencyService.proto b/service/src/main/proto/KeyTransparencyService.proto index 4d717b757..3ba4c60f0 100644 --- a/service/src/main/proto/KeyTransparencyService.proto +++ b/service/src/main/proto/KeyTransparencyService.proto @@ -15,6 +15,13 @@ package kt_query; * to look up and monitor search keys. */ service KeyTransparencyQueryService { + /** + * An endpoint used by clients to retrieve the most recent distinguished tree + * head, which should be used to derive consistency parameters for + * subsequent Search and Monitor requests. It should be the first key + * transparency RPC a client calls. + */ + rpc Distinguished(DistinguishedRequest) returns (SearchResponse) {} /** * An endpoint used by clients to search for a given key in the transparency log. * The server returns proof that the search key exists in the log. @@ -48,6 +55,19 @@ message ConsistencyParameters { optional uint64 distinguished = 2; } +/** + * DistinguishedRequest looks up the most recent distinguished key in the + * transparency log. + */ +message DistinguishedRequest { + /** + * The tree size of the client's last verified distinguished request. With the + * exception of a client's very first request, this field should always be + * set. + */ + optional uint64 last = 1; +} + message SearchRequest { /** * The key to look up in the log tree. diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java index 58f23c3fe..d1e411c3f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java @@ -5,6 +5,23 @@ package org.whispersystems.textsecuregcm.controllers; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.ACI_PREFIX; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.E164_PREFIX; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.USERNAME_PREFIX; +import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.getFullSearchKeyByteString; + import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.net.HttpHeaders; import com.google.i18n.phonenumbers.PhoneNumberUtil; @@ -14,6 +31,23 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.ws.rs.client.Entity; +import javax.ws.rs.client.Invocation; +import javax.ws.rs.client.WebTarget; +import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -21,11 +55,13 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.entities.KeyTransparencyDistinguishedKeyResponse; import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorRequest; import org.whispersystems.textsecuregcm.entities.KeyTransparencyMonitorResponse; import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchRequest; @@ -41,39 +77,6 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; -import javax.ws.rs.client.Entity; -import javax.ws.rs.client.Invocation; -import javax.ws.rs.core.Response; -import java.io.UncheckedIOException; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; -import java.util.stream.Stream; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.ACI_PREFIX; -import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.E164_PREFIX; -import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.USERNAME_PREFIX; -import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.getFullSearchKeyByteString; - @ExtendWith(DropwizardExtensionsSupport.class) public class KeyTransparencyControllerTest { @@ -90,6 +93,7 @@ public class KeyTransparencyControllerTest { private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiter searchRatelimiter = mock(RateLimiter.class); private static final RateLimiter monitorRatelimiter = mock(RateLimiter.class); + private static final RateLimiter distinguishedRatelimiter = mock(RateLimiter.class); private final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) @@ -103,9 +107,10 @@ public class KeyTransparencyControllerTest { @BeforeEach void setup() { - when(rateLimiters.forDescriptor(eq(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP))).thenReturn(searchRatelimiter); - when(rateLimiters.forDescriptor(eq(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP))).thenReturn( - monitorRatelimiter); + when(rateLimiters.forDescriptor(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP)).thenReturn( + distinguishedRatelimiter); + when(rateLimiters.forDescriptor(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP)).thenReturn(searchRatelimiter); + when(rateLimiters.forDescriptor(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP)).thenReturn(monitorRatelimiter); } @AfterEach @@ -211,7 +216,7 @@ public class KeyTransparencyControllerTest { ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(400, response.getStatus()); } - verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); } @ParameterizedTest @@ -255,7 +260,7 @@ public class KeyTransparencyControllerTest { createSearchRequestJson(aci, e164, Optional.empty(), aciIdentityKey, unidentifiedAccessKey, lastTreeHeadSize, distinguishedTreeHeadSize)))) { assertEquals(422, response.getStatus()); - verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); } } @@ -277,7 +282,7 @@ public class KeyTransparencyControllerTest { } @Test - void searchRatelimited() { + void searchRateLimited() { MockUtils.updateRateLimiterResponseToFail( rateLimiters, RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP, "127.0.0.1", Duration.ofMinutes(10), true); final Invocation.Builder request = resources.getJerseyTest() @@ -286,7 +291,7 @@ public class KeyTransparencyControllerTest { try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty(), ACI_IDENTITY_KEY, Optional.empty(),Optional.empty(), Optional.empty())))) { assertEquals(429, response.getStatus()); - verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); } } @@ -326,7 +331,7 @@ public class KeyTransparencyControllerTest { Entity.json(createMonitorRequestJson(ACI, List.of(3L), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(400, response.getStatus()); - verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); } } @@ -364,7 +369,7 @@ public class KeyTransparencyControllerTest { .request(); try (Response response = request.post(Entity.json(requestJson))) { assertEquals(422, response.getStatus()); - verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); } } @@ -405,7 +410,7 @@ public class KeyTransparencyControllerTest { } @Test - void monitorRatelimited() { + void monitorRateLimited() { MockUtils.updateRateLimiterResponseToFail( rateLimiters, RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP, "127.0.0.1", Duration.ofMinutes(10), true); final Invocation.Builder request = resources.getJerseyTest() @@ -415,7 +420,100 @@ public class KeyTransparencyControllerTest { Entity.json(createMonitorRequestJson(ACI, List.of(3L), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(429, response.getStatus()); - verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); + verifyNoInteractions(keyTransparencyServiceClient); + } + } + + @ParameterizedTest + @CsvSource(", 1") + void distinguishedSuccess(@Nullable Long lastTreeHeadSize) { + when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) + .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); + + WebTarget webTarget = resources.getJerseyTest() + .target("/v1/key-transparency/distinguished"); + + if (lastTreeHeadSize != null) { + webTarget = webTarget.queryParam("lastTreeHeadSize", lastTreeHeadSize); + } + + try (Response response = webTarget.request().get()) { + assertEquals(200, response.getStatus()); + + final KeyTransparencyDistinguishedKeyResponse distinguishedKeyResponse = response.readEntity( + KeyTransparencyDistinguishedKeyResponse.class); + assertNotNull(distinguishedKeyResponse.distinguishedKeyResponse()); + + verify(keyTransparencyServiceClient, times(1)) + .getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)), + eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); + } + } + + @Test + void distinguishedAuthenticated() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/key-transparency/distinguished") + .request() + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.get()) { + assertEquals(400, response.getStatus()); + } + verifyNoInteractions(keyTransparencyServiceClient); + } + + @ParameterizedTest + @MethodSource + void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) { + when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) + .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/key-transparency/distinguished") + .request(); + try (Response response = request.get()) { + assertEquals(httpStatus, response.getStatus()); + verify(keyTransparencyServiceClient).getDistinguishedKey(any(), any()); + } + } + + private static Stream distinguishedGrpcErrors() { + return Stream.of( + Arguments.of(Status.NOT_FOUND, 404), + Arguments.of(Status.PERMISSION_DENIED, 403), + Arguments.of(Status.INVALID_ARGUMENT, 422), + Arguments.of(Status.UNKNOWN, 500) + ); + } + + @Test + void distinguishedInvalidRequest() { + when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) + .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/key-transparency/distinguished") + .queryParam("lastTreeHeadSize", -1) + .request(); + + try (Response response = request.get()) { + assertEquals(400, response.getStatus()); + + verifyNoInteractions(keyTransparencyServiceClient); + } + } + + @Test + void distinguishedRateLimited() { + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP, "127.0.0.1", Duration.ofMinutes(10), + true); + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/key-transparency/distinguished") + .request(); + try (Response response = request.get()) { + assertEquals(429, response.getStatus()); + verifyNoInteractions(keyTransparencyServiceClient); } }