Implement key transparency endpoints using `simple-grpc`

This commit is contained in:
Katherine 2025-06-24 14:01:35 -04:00 committed by GitHub
parent 51773f5709
commit 059caa4c57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 562 additions and 116 deletions

View File

@ -554,8 +554,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.maxThreads(2) .maxThreads(2)
.minThreads(2) .minThreads(2)
.build(); .build();
ExecutorService keyTransparencyCallbackExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "keyTransparency-%d"));
ExecutorService googlePlayBillingExecutor = environment.lifecycle() ExecutorService googlePlayBillingExecutor = environment.lifecycle()
.virtualExecutorService(name(getClass(), "googlePlayBilling-%d")); .virtualExecutorService(name(getClass(), "googlePlayBilling-%d"));
ExecutorService appleAppStoreExecutor = environment.lifecycle() ExecutorService appleAppStoreExecutor = environment.lifecycle()
@ -606,8 +604,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getKeyTransparencyServiceConfiguration().port(), config.getKeyTransparencyServiceConfiguration().port(),
config.getKeyTransparencyServiceConfiguration().tlsCertificate(), config.getKeyTransparencyServiceConfiguration().tlsCertificate(),
config.getKeyTransparencyServiceConfiguration().clientCertificate(), config.getKeyTransparencyServiceConfiguration().clientCertificate(),
config.getKeyTransparencyServiceConfiguration().clientPrivateKey().value(), config.getKeyTransparencyServiceConfiguration().clientPrivateKey().value());
keyTransparencyCallbackExecutor);
SecureValueRecovery2Client secureValueRecovery2Client = new SecureValueRecovery2Client(svr2CredentialsGenerator, SecureValueRecovery2Client secureValueRecovery2Client = new SecureValueRecovery2Client(svr2CredentialsGenerator,
secureValueRecovery2ServiceExecutor, secureValueRecoveryServiceRetryExecutor, config.getSvr2Configuration()); secureValueRecovery2ServiceExecutor, secureValueRecoveryServiceRetryExecutor, config.getSvr2Configuration());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,

View File

@ -31,8 +31,7 @@ import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CancellationException; import org.glassfish.jersey.server.ManagedAsync;
import java.util.concurrent.CompletionException;
import org.signal.keytransparency.client.AciMonitorRequest; import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.E164MonitorRequest; import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.E164SearchRequest;
@ -48,15 +47,12 @@ import org.whispersystems.textsecuregcm.entities.KeyTransparencySearchResponse;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient; import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
@Path("/v1/key-transparency") @Path("/v1/key-transparency")
@Tag(name = "KeyTransparency") @Tag(name = "KeyTransparency")
public class KeyTransparencyController { public class KeyTransparencyController {
private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class); private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class);
@VisibleForTesting
static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15);
private final KeyTransparencyServiceClient keyTransparencyServiceClient; private final KeyTransparencyServiceClient keyTransparencyServiceClient;
public KeyTransparencyController( public KeyTransparencyController(
@ -88,6 +84,7 @@ public class KeyTransparencyController {
@Path("/search") @Path("/search")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP) @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencySearchResponse search( public KeyTransparencySearchResponse search(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount, @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencySearchRequest request) { @NotNull @Valid final KeyTransparencySearchRequest request) {
@ -104,19 +101,17 @@ public class KeyTransparencyController {
.build() .build()
)); ));
return keyTransparencyServiceClient.search( return new KeyTransparencySearchResponse(
keyTransparencyServiceClient.search(
ByteString.copyFrom(request.aci().toCompactByteArray()), ByteString.copyFrom(request.aci().toCompactByteArray()),
ByteString.copyFrom(request.aciIdentityKey().serialize()), ByteString.copyFrom(request.aciIdentityKey().serialize()),
request.usernameHash().map(ByteString::copyFrom), request.usernameHash().map(ByteString::copyFrom),
maybeE164SearchRequest, maybeE164SearchRequest,
request.lastTreeHeadSize(), request.lastTreeHeadSize(),
request.distinguishedTreeHeadSize(), request.distinguishedTreeHeadSize())
KEY_TRANSPARENCY_RPC_TIMEOUT) .toByteArray());
.thenApply(KeyTransparencySearchResponse::new).join(); } catch (final StatusRuntimeException exception) {
} catch (final CancellationException exception) { LOGGER.error("Unexpected error calling key transparency service", exception);
LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
handleKeyTransparencyServiceError(exception); handleKeyTransparencyServiceError(exception);
} }
// This is unreachable // This is unreachable
@ -140,6 +135,7 @@ public class KeyTransparencyController {
@Path("/monitor") @Path("/monitor")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP) @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencyMonitorResponse monitor( public KeyTransparencyMonitorResponse monitor(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount, @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencyMonitorRequest request) { @NotNull @Valid final KeyTransparencyMonitorRequest request) {
@ -173,13 +169,10 @@ public class KeyTransparencyController {
usernameHashMonitorRequest, usernameHashMonitorRequest,
e164MonitorRequest, e164MonitorRequest,
request.lastNonDistinguishedTreeHeadSize(), request.lastNonDistinguishedTreeHeadSize(),
request.lastDistinguishedTreeHeadSize(), request.lastDistinguishedTreeHeadSize())
KEY_TRANSPARENCY_RPC_TIMEOUT).join()); .toByteArray());
} catch (final StatusRuntimeException exception) {
} catch (final CancellationException exception) { LOGGER.error("Unexpected error calling key transparency service", exception);
LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
handleKeyTransparencyServiceError(exception); handleKeyTransparencyServiceError(exception);
} }
// This is unreachable // This is unreachable
@ -202,6 +195,7 @@ public class KeyTransparencyController {
@Path("/distinguished") @Path("/distinguished")
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP) @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ManagedAsync
public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey( public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey(
@Auth final Optional<AuthenticatedDevice> authenticatedAccount, @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@ -212,34 +206,26 @@ public class KeyTransparencyController {
requireNotAuthenticated(authenticatedAccount); requireNotAuthenticated(authenticatedAccount);
try { try {
return keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize, KEY_TRANSPARENCY_RPC_TIMEOUT) return new KeyTransparencyDistinguishedKeyResponse(
.thenApply(KeyTransparencyDistinguishedKeyResponse::new) keyTransparencyServiceClient.getDistinguishedKey(lastTreeHeadSize)
.join(); .toByteArray());
} catch (final CancellationException exception) { } catch (final StatusRuntimeException exception) {
LOGGER.error("Unexpected cancellation from key transparency service", exception); LOGGER.error("Unexpected error calling key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
} catch (final CompletionException exception) {
handleKeyTransparencyServiceError(exception); handleKeyTransparencyServiceError(exception);
} }
// This is unreachable // This is unreachable
return null; return null;
} }
private void handleKeyTransparencyServiceError(final CompletionException exception) { private void handleKeyTransparencyServiceError(final StatusRuntimeException exception) {
final Throwable unwrapped = ExceptionUtils.unwrap(exception); final Status.Code code = exception.getStatus().getCode();
final String description = exception.getStatus().getDescription();
if (unwrapped instanceof StatusRuntimeException e) { switch (code) {
final Status.Code code = e.getStatus().getCode(); case NOT_FOUND -> throw new NotFoundException(description);
final String description = e.getStatus().getDescription(); case PERMISSION_DENIED -> throw new ForbiddenException(description);
switch (code) { case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422);
case NOT_FOUND -> throw new NotFoundException(description); default -> throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, exception);
case PERMISSION_DENIED -> throw new ForbiddenException(description);
case INVALID_ARGUMENT -> throw new WebApplicationException(description, 422);
default -> throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped);
}
} }
LOGGER.error("Unexpected key transparency service failure", unwrapped);
throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped);
} }
private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) { private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) {

View File

@ -0,0 +1,140 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.Status;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.SimpleKeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
public class KeyTransparencyGrpcService extends
SimpleKeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceImplBase {
@VisibleForTesting
static final int COMMITMENT_INDEX_LENGTH = 32;
private final RateLimiters rateLimiters;
private final KeyTransparencyServiceClient client;
public KeyTransparencyGrpcService(final RateLimiters rateLimiters,
final KeyTransparencyServiceClient client) {
this.rateLimiters = rateLimiters;
this.client = client;
}
@Override
public SearchResponse search(final SearchRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencySearchLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
return client.search(validateSearchRequest(request));
}
@Override
public MonitorResponse monitor(final MonitorRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencyMonitorLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
return client.monitor(validateMonitorRequest(request));
}
@Override
public DistinguishedResponse distinguished(final DistinguishedRequest request) throws RateLimitExceededException {
rateLimiters.getKeyTransparencyDistinguishedLimiter().validate(RequestAttributesUtil.getRemoteAddress().getHostAddress());
// A client's very first distinguished request will not have a "last" parameter
if (request.hasLast() && request.getLast() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException();
}
return client.distinguished(request);
}
private SearchRequest validateSearchRequest(final SearchRequest request) {
if (request.hasE164SearchRequest()) {
final E164SearchRequest e164SearchRequest = request.getE164SearchRequest();
if (e164SearchRequest.getUnidentifiedAccessKey().isEmpty() != e164SearchRequest.getE164().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("Unidentified access key and E164 must be provided together or not at all").asRuntimeException();
}
}
if (!request.getConsistency().hasDistinguished()) {
throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished tree head size").asRuntimeException();
}
validateConsistencyParameters(request.getConsistency());
return request;
}
private MonitorRequest validateMonitorRequest(final MonitorRequest request) {
final AciMonitorRequest aciMonitorRequest = request.getAci();
try {
AciServiceIdentifier.fromBytes(aciMonitorRequest.getAci().toByteArray());
} catch (IllegalArgumentException e) {
throw Status.INVALID_ARGUMENT.withDescription("Invalid ACI").asRuntimeException();
}
if (aciMonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Aci entry position must be positive").asRuntimeException();
}
if (aciMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Aci commitment index must be 32 bytes").asRuntimeException();
}
if (request.hasUsernameHash()) {
final UsernameHashMonitorRequest usernameHashMonitorRequest = request.getUsernameHash();
if (usernameHashMonitorRequest.getUsernameHash().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash cannot be empty").asRuntimeException();
}
if (usernameHashMonitorRequest.getUsernameHash().size() != AccountController.USERNAME_HASH_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Invalid username hash length").asRuntimeException();
}
if (usernameHashMonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash entry position must be positive").asRuntimeException();
}
if (usernameHashMonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("Username hash commitment index must be 32 bytes").asRuntimeException();
}
}
if (request.hasE164()) {
final E164MonitorRequest e164MonitorRequest = request.getE164();
if (e164MonitorRequest.getE164().isEmpty()) {
throw Status.INVALID_ARGUMENT.withDescription("E164 cannot be empty").asRuntimeException();
}
if (e164MonitorRequest.getEntryPosition() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("E164 entry position must be positive").asRuntimeException();
}
if (e164MonitorRequest.getCommitmentIndex().size() != COMMITMENT_INDEX_LENGTH) {
throw Status.INVALID_ARGUMENT.withDescription("E164 commitment index must be 32 bytes").asRuntimeException();
}
}
if (!request.getConsistency().hasDistinguished() || !request.getConsistency().hasLast()) {
throw Status.INVALID_ARGUMENT.withDescription("Must provide distinguished and last tree head sizes").asRuntimeException();
}
validateConsistencyParameters(request.getConsistency());
return request;
}
private static void validateConsistencyParameters(final ConsistencyParameters consistency) {
if (consistency.getDistinguished() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Distinguished tree head size must be positive").asRuntimeException();
}
if (consistency.hasLast() && consistency.getLast() <= 0) {
throw Status.INVALID_ARGUMENT.withDescription("Last tree head size must be positive").asRuntimeException();
}
}
}

View File

@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.keytransparency; package org.whispersystems.textsecuregcm.keytransparency;
import com.google.protobuf.AbstractMessageLite;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
@ -20,44 +19,43 @@ import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Collection; import java.util.Collection;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.signal.keytransparency.client.AciMonitorRequest; import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters; import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest; import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest; import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc; import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.MonitorRequest; import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest; import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UsernameHashMonitorRequest; import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.CompletableFutureUtil;
public class KeyTransparencyServiceClient implements Managed { public class KeyTransparencyServiceClient implements Managed {
private static final String DAYS_UNTIL_CLIENT_CERTIFICATE_EXPIRATION_GAUGE_NAME = private static final String DAYS_UNTIL_CLIENT_CERTIFICATE_EXPIRATION_GAUGE_NAME =
MetricsUtil.name(KeyTransparencyServiceClient.class, "daysUntilClientCertificateExpiration"); MetricsUtil.name(KeyTransparencyServiceClient.class, "daysUntilClientCertificateExpiration");
private static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final Logger logger = LoggerFactory.getLogger(KeyTransparencyServiceClient.class); private static final Logger logger = LoggerFactory.getLogger(KeyTransparencyServiceClient.class);
private final Executor callbackExecutor;
private final String host; private final String host;
private final int port; private final int port;
private final ChannelCredentials tlsChannelCredentials; private final ChannelCredentials tlsChannelCredentials;
private ManagedChannel channel; private ManagedChannel channel;
private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceFutureStub stub; private KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub stub;
public KeyTransparencyServiceClient( public KeyTransparencyServiceClient(
final String host, final String host,
final int port, final int port,
final String tlsCertificate, final String tlsCertificate,
final String clientCertificate, final String clientCertificate,
final String clientPrivateKey, final String clientPrivateKey
final Executor callbackExecutor
) throws IOException { ) throws IOException {
this.host = host; this.host = host;
this.port = port; this.port = port;
@ -76,7 +74,6 @@ public class KeyTransparencyServiceClient implements Managed {
configureClientCertificateMetrics(clientCertificate); configureClientCertificateMetrics(clientCertificate);
} }
this.callbackExecutor = callbackExecutor;
} }
private void configureClientCertificateMetrics(String clientCertificate) { private void configureClientCertificateMetrics(String clientCertificate) {
@ -113,14 +110,13 @@ public class KeyTransparencyServiceClient implements Managed {
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> search( public SearchResponse search(
final ByteString aci, final ByteString aci,
final ByteString aciIdentityKey, final ByteString aciIdentityKey,
final Optional<ByteString> usernameHash, final Optional<ByteString> usernameHash,
final Optional<E164SearchRequest> e164SearchRequest, final Optional<E164SearchRequest> e164SearchRequest,
final Optional<Long> lastTreeHeadSize, final Optional<Long> lastTreeHeadSize,
final long distinguishedTreeHeadSize, final long distinguishedTreeHeadSize) {
final Duration timeout) {
final SearchRequest.Builder searchRequestBuilder = SearchRequest.newBuilder() final SearchRequest.Builder searchRequestBuilder = SearchRequest.newBuilder()
.setAci(aci) .setAci(aci)
.setAciIdentityKey(aciIdentityKey); .setAciIdentityKey(aciIdentityKey);
@ -133,19 +129,20 @@ public class KeyTransparencyServiceClient implements Managed {
lastTreeHeadSize.ifPresent(consistency::setLast); lastTreeHeadSize.ifPresent(consistency::setLast);
searchRequestBuilder.setConsistency(consistency.build()); searchRequestBuilder.setConsistency(consistency.build());
return search(searchRequestBuilder.build());
}
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) public SearchResponse search(final SearchRequest request) {
.search(searchRequestBuilder.build()), callbackExecutor) return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.thenApply(AbstractMessageLite::toByteArray); .search(request);
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> monitor(final AciMonitorRequest aciMonitorRequest, public MonitorResponse monitor(final AciMonitorRequest aciMonitorRequest,
final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest, final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest,
final Optional<E164MonitorRequest> e164MonitorRequest, final Optional<E164MonitorRequest> e164MonitorRequest,
final long lastTreeHeadSize, final long lastTreeHeadSize,
final long distinguishedTreeHeadSize, final long distinguishedTreeHeadSize) {
final Duration timeout) {
final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder() final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest) .setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder() .setConsistency(ConsistencyParameters.newBuilder()
@ -155,20 +152,26 @@ public class KeyTransparencyServiceClient implements Managed {
usernameHashMonitorRequest.ifPresent(monitorRequestBuilder::setUsernameHash); usernameHashMonitorRequest.ifPresent(monitorRequestBuilder::setUsernameHash);
e164MonitorRequest.ifPresent(monitorRequestBuilder::setE164); e164MonitorRequest.ifPresent(monitorRequestBuilder::setE164);
return monitor(monitorRequestBuilder.build());
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.monitor(monitorRequestBuilder.build()), callbackExecutor)
.thenApply(AbstractMessageLite::toByteArray);
} }
public MonitorResponse monitor(final MonitorRequest request) {
return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.monitor(request);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> getDistinguishedKey(final Optional<Long> lastTreeHeadSize, final Duration timeout) { public DistinguishedResponse getDistinguishedKey(final Optional<Long> lastTreeHeadSize) {
final DistinguishedRequest request = lastTreeHeadSize.map( final DistinguishedRequest request = lastTreeHeadSize.map(
last -> DistinguishedRequest.newBuilder().setLast(last).build()) last -> DistinguishedRequest.newBuilder().setLast(last).build())
.orElseGet(DistinguishedRequest::getDefaultInstance); .orElseGet(DistinguishedRequest::getDefaultInstance);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)).distinguished(request), return distinguished(request);
callbackExecutor) }
.thenApply(AbstractMessageLite::toByteArray);
public DistinguishedResponse distinguished(final DistinguishedRequest request) {
return stub.withDeadline(toDeadline(KEY_TRANSPARENCY_RPC_TIMEOUT))
.distinguished(request);
} }
private static Deadline toDeadline(final Duration timeout) { private static Deadline toDeadline(final Duration timeout) {
@ -180,7 +183,7 @@ public class KeyTransparencyServiceClient implements Managed {
channel = Grpc.newChannelBuilderForAddress(host, port, tlsChannelCredentials) channel = Grpc.newChannelBuilderForAddress(host, port, tlsChannelCredentials)
.idleTimeout(1, TimeUnit.MINUTES) .idleTimeout(1, TimeUnit.MINUTES)
.build(); .build();
stub = KeyTransparencyQueryServiceGrpc.newFutureStub(channel); stub = KeyTransparencyQueryServiceGrpc.newBlockingStub(channel);
} }
@Override @Override

View File

@ -206,4 +206,16 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public RateLimiter getWaitForTransferArchiveLimiter() { public RateLimiter getWaitForTransferArchiveLimiter() {
return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE); return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE);
} }
public RateLimiter getKeyTransparencySearchLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_SEARCH_PER_IP);
}
public RateLimiter getKeyTransparencyDistinguishedLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP);
}
public RateLimiter getKeyTransparencyMonitorLimiter() {
return forDescriptor(For.KEY_TRANSPARENCY_MONITOR_PER_IP);
}
} }

View File

@ -10,6 +10,8 @@ option java_package = "org.signal.keytransparency.client";
package kt_query; package kt_query;
import "org/signal/chat/require.proto";
/** /**
* An external-facing, read-only key transparency service used by Signal's chat server * An external-facing, read-only key transparency service used by Signal's chat server
* to look up and monitor identifiers. * to look up and monitor identifiers.
@ -19,8 +21,13 @@ package kt_query;
* - A username hash which also maps to an ACI * - A username hash which also maps to an ACI
* Separately, the log also stores and periodically updates a fixed value known as the `distinguished` key. * Separately, the log also stores and periodically updates a fixed value known as the `distinguished` key.
* Clients use the verified tree head from looking up this key for future calls to the Search and Monitor endpoints. * Clients use the verified tree head from looking up this key for future calls to the Search and Monitor endpoints.
*
* Note that this service definition is used in two different contexts:
* 1. Implementing the endpoints with rate-limiting and request validation
* 2. Using the generated client stub to forward requests to the remote key transparency service
*/ */
service KeyTransparencyQueryService { service KeyTransparencyQueryService {
option (org.signal.chat.require.auth) = AUTH_ONLY_ANONYMOUS;
/** /**
* An endpoint used by clients to retrieve the most recent distinguished tree * An endpoint used by clients to retrieve the most recent distinguished tree
* head, which should be used to derive consistency parameters for * head, which should be used to derive consistency parameters for
@ -44,15 +51,15 @@ message SearchRequest {
/** /**
* The ACI to look up in the log. * The ACI to look up in the log.
*/ */
bytes aci = 1; bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16];
/** /**
* The ACI identity key that the client thinks the ACI maps to in the log. * The ACI identity key that the client thinks the ACI maps to in the log.
*/ */
bytes aci_identity_key = 2; bytes aci_identity_key = 2 [(org.signal.chat.require.nonEmpty) = true];
/** /**
* The username hash to look up in the log. * The username hash to look up in the log.
*/ */
optional bytes username_hash = 3; optional bytes username_hash = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
/** /**
* The E164 to look up in the log along with associated data. * The E164 to look up in the log along with associated data.
*/ */
@ -60,7 +67,7 @@ message SearchRequest {
/** /**
* The tree head size(s) to prove consistency against. * The tree head size(s) to prove consistency against.
*/ */
ConsistencyParameters consistency = 5; ConsistencyParameters consistency = 5 [(org.signal.chat.require.present) = true];
} }
/** /**
@ -70,7 +77,7 @@ message E164SearchRequest {
/** /**
* The E164 that the client wishes to look up in the transparency log. * The E164 that the client wishes to look up in the transparency log.
*/ */
string e164 = 1; optional string e164 = 1 [(org.signal.chat.require.e164) = true];
/** /**
* The unidentified access key of the account associated with the provided E164. * The unidentified access key of the account associated with the provided E164.
*/ */
@ -328,28 +335,28 @@ message PrefixSearchResult {
} }
message MonitorRequest { message MonitorRequest {
AciMonitorRequest aci = 1; AciMonitorRequest aci = 1 [(org.signal.chat.require.present) = true];
optional UsernameHashMonitorRequest username_hash = 2; optional UsernameHashMonitorRequest username_hash = 2;
optional E164MonitorRequest e164 = 3; optional E164MonitorRequest e164 = 3;
ConsistencyParameters consistency = 4; ConsistencyParameters consistency = 4 [(org.signal.chat.require.present) = true];
} }
message AciMonitorRequest { message AciMonitorRequest {
bytes aci = 1; bytes aci = 1 [(org.signal.chat.require.exactlySize) = 16];
uint64 entry_position = 2; uint64 entry_position = 2;
bytes commitment_index = 3; bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 32];
} }
message UsernameHashMonitorRequest { message UsernameHashMonitorRequest {
bytes username_hash = 1; bytes username_hash = 1 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
uint64 entry_position = 2; uint64 entry_position = 2;
bytes commitment_index = 3; bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
} }
message E164MonitorRequest { message E164MonitorRequest {
string e164 = 1; optional string e164 = 1 [(org.signal.chat.require.e164) = true];
uint64 entry_position = 2; uint64 entry_position = 2;
bytes commitment_index = 3; bytes commitment_index = 3 [(org.signal.chat.require.exactlySize) = 0, (org.signal.chat.require.exactlySize) = 32];
} }
message MonitorProof { message MonitorProof {

View File

@ -35,12 +35,8 @@ import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.time.Duration; import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
@ -54,8 +50,10 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.keytransparency.client.CondensedTreeSearchResponse; import org.signal.keytransparency.client.CondensedTreeSearchResponse;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.FullTreeHead; import org.signal.keytransparency.client.FullTreeHead;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchProof; import org.signal.keytransparency.client.SearchProof;
import org.signal.keytransparency.client.SearchResponse; import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UpdateValue; import org.signal.keytransparency.client.UpdateValue;
@ -81,16 +79,16 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
public class KeyTransparencyControllerTest { public class KeyTransparencyControllerTest {
private static final String NUMBER = PhoneNumberUtil.getInstance().format( public static final String NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
private static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID()); public static final AciServiceIdentifier ACI = new AciServiceIdentifier(UUID.randomUUID());
private static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20); public static final byte[] USERNAME_HASH = TestRandomUtil.nextBytes(20);
private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
= new TestRemoteAddressFilterProvider("127.0.0.1"); = new TestRemoteAddressFilterProvider("127.0.0.1");
private static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); public static final IdentityKey ACI_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] COMMITMENT_INDEX = new byte[32]; private static final byte[] COMMITMENT_INDEX = new byte[32];
private static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16]; public static final byte[] UNIDENTIFIED_ACCESS_KEY = new byte[16];
private final KeyTransparencyServiceClient keyTransparencyServiceClient = mock(KeyTransparencyServiceClient.class); private final KeyTransparencyServiceClient keyTransparencyServiceClient = mock(KeyTransparencyServiceClient.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter searchRatelimiter = mock(RateLimiter.class); private static final RateLimiter searchRatelimiter = mock(RateLimiter.class);
@ -141,8 +139,8 @@ public class KeyTransparencyControllerTest {
e164.ifPresent(ignored -> searchResponseBuilder.setE164(CondensedTreeSearchResponse.getDefaultInstance())); e164.ifPresent(ignored -> searchResponseBuilder.setE164(CondensedTreeSearchResponse.getDefaultInstance()));
usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(CondensedTreeSearchResponse.getDefaultInstance())); usernameHash.ifPresent(ignored -> searchResponseBuilder.setUsernameHash(CondensedTreeSearchResponse.getDefaultInstance()));
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any())) when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(searchResponseBuilder.build().toByteArray())); .thenReturn(searchResponseBuilder.build());
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search") .target("/v1/key-transparency/search")
@ -167,8 +165,7 @@ public class KeyTransparencyControllerTest {
ArgumentCaptor<Optional<E164SearchRequest>> e164Argument = ArgumentCaptor.forClass(Optional.class); ArgumentCaptor<Optional<E164SearchRequest>> e164Argument = ArgumentCaptor.forClass(Optional.class);
verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(), verify(keyTransparencyServiceClient).search(aciArgument.capture(), aciIdentityKeyArgument.capture(),
usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L), usernameHashArgument.capture(), e164Argument.capture(), eq(Optional.of(3L)), eq(4L));
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray()); assertArrayEquals(ACI.toCompactByteArray(), aciArgument.getValue().toByteArray());
assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray()); assertArrayEquals(ACI_IDENTITY_KEY.serialize(), aciIdentityKeyArgument.getValue().toByteArray());
@ -218,8 +215,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void searchGrpcErrors(final Status grpcStatus, final int httpStatus) { void searchGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong(), any())) when(keyTransparencyServiceClient.search(any(), any(), any(), any(), any(), anyLong()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); .thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/search") .target("/v1/key-transparency/search")
@ -228,7 +225,7 @@ public class KeyTransparencyControllerTest {
Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(), Entity.json(createRequestJson(new KeyTransparencySearchRequest(ACI, Optional.empty(), Optional.empty(),
ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), 4L))))) { ACI_IDENTITY_KEY, Optional.empty(), Optional.empty(), 4L))))) {
assertEquals(httpStatus, response.getStatus()); assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong(), any()); verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any(), any(), anyLong());
} }
} }
@ -295,8 +292,8 @@ public class KeyTransparencyControllerTest {
@Test @Test
void monitorSuccess() { void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any())) when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); .thenReturn(MonitorResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor") .target("/v1/key-transparency/monitor")
@ -314,7 +311,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(keyTransparencyMonitorResponse.serializedResponse()); assertNotNull(keyTransparencyMonitorResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1)).monitor( verify(keyTransparencyServiceClient, times(1)).monitor(
any(), any(), any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); any(), any(), any(), eq(3L), eq(4L));
} }
} }
@ -337,8 +334,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) { void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any())) when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); .thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/monitor") .target("/v1/key-transparency/monitor")
@ -349,7 +346,7 @@ public class KeyTransparencyControllerTest {
new KeyTransparencyMonitorRequest.AciMonitor(ACI, 3, COMMITMENT_INDEX), new KeyTransparencyMonitorRequest.AciMonitor(ACI, 3, COMMITMENT_INDEX),
Optional.empty(), Optional.empty(), 3L, 4L))))) { Optional.empty(), Optional.empty(), 3L, 4L))))) {
assertEquals(httpStatus, response.getStatus()); assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong(), any()); verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong());
} }
} }
@ -500,8 +497,8 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest @ParameterizedTest
@CsvSource(", 1") @CsvSource(", 1")
void distinguishedSuccess(@Nullable Long lastTreeHeadSize) { void distinguishedSuccess(@Nullable Long lastTreeHeadSize) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); .thenReturn(DistinguishedResponse.getDefaultInstance());
WebTarget webTarget = resources.getJerseyTest() WebTarget webTarget = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished"); .target("/v1/key-transparency/distinguished");
@ -518,8 +515,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(distinguishedKeyResponse.serializedResponse()); assertNotNull(distinguishedKeyResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1)) verify(keyTransparencyServiceClient, times(1))
.getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)), .getDistinguishedKey(eq(Optional.ofNullable(lastTreeHeadSize)));
eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
} }
} }
@ -538,15 +534,15 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) { void distinguishedGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); .thenThrow(new StatusRuntimeException(grpcStatus));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished") .target("/v1/key-transparency/distinguished")
.request(); .request();
try (Response response = request.get()) { try (Response response = request.get()) {
assertEquals(httpStatus, response.getStatus()); assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient).getDistinguishedKey(any(), any()); verify(keyTransparencyServiceClient).getDistinguishedKey(any());
} }
} }
@ -561,8 +557,8 @@ public class KeyTransparencyControllerTest {
@Test @Test
void distinguishedInvalidRequest() { void distinguishedInvalidRequest() {
when(keyTransparencyServiceClient.getDistinguishedKey(any(), any())) when(keyTransparencyServiceClient.getDistinguishedKey(any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); .thenReturn(DistinguishedResponse.getDefaultInstance());
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/key-transparency/distinguished") .target("/v1/key-transparency/distinguished")

View File

@ -0,0 +1,305 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.Status;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.DistinguishedResponse;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.MonitorResponse;
import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.SearchResponse;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceClient;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import reactor.core.publisher.Mono;
import java.time.Duration;
import java.util.Optional;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.ACI_IDENTITY_KEY;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.NUMBER;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.UNIDENTIFIED_ACCESS_KEY;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyControllerTest.USERNAME_HASH;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertRateLimitExceeded;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import static org.whispersystems.textsecuregcm.grpc.KeyTransparencyGrpcService.COMMITMENT_INDEX_LENGTH;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class KeyTransparencyGrpcServiceTest extends SimpleBaseGrpcTest<KeyTransparencyGrpcService, KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub>{
@Mock
private KeyTransparencyServiceClient keyTransparencyServiceClient;
@Mock
private RateLimiter rateLimiter;
@Override
protected KeyTransparencyGrpcService createServiceBeforeEachTest() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.getKeyTransparencySearchLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getKeyTransparencyDistinguishedLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getKeyTransparencyMonitorLimiter()).thenReturn(rateLimiter);
return new KeyTransparencyGrpcService(rateLimiters, keyTransparencyServiceClient);
}
@Override
protected KeyTransparencyQueryServiceGrpc.KeyTransparencyQueryServiceBlockingStub createStub(final Channel channel) {
return KeyTransparencyQueryServiceGrpc.newBlockingStub(channel);
}
@Test
void searchSuccess() throws RateLimitExceededException {
when(keyTransparencyServiceClient.search(any())).thenReturn(SearchResponse.getDefaultInstance());
Mockito.doNothing().when(rateLimiter).validate(any(String.class));
final SearchRequest request = SearchRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize()))
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.build())
.build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().search(request));
verify(keyTransparencyServiceClient, times(1)).search(eq(request));
}
@ParameterizedTest
@MethodSource
void searchInvalidRequest(final Optional<byte[]> aciServiceIdentifier,
final Optional<IdentityKey> aciIdentityKey,
final Optional<String> e164,
final Optional<byte[]> unidentifiedAccessKey,
final Optional<byte[]> usernameHash,
final Optional<Long> lastTreeHeadSize,
final Optional<Long> distinguishedTreeHeadSize) {
final SearchRequest.Builder requestBuilder = SearchRequest.newBuilder();
aciServiceIdentifier.ifPresent(v -> requestBuilder.setAci(ByteString.copyFrom(v)));
aciIdentityKey.ifPresent(v -> requestBuilder.setAciIdentityKey(ByteString.copyFrom(v.serialize())));
usernameHash.ifPresent(v -> requestBuilder.setUsernameHash(ByteString.copyFrom(v)));
final E164SearchRequest.Builder e164RequestBuilder = E164SearchRequest.newBuilder();
e164.ifPresent(e164RequestBuilder::setE164);
unidentifiedAccessKey.ifPresent(v -> e164RequestBuilder.setUnidentifiedAccessKey(ByteString.copyFrom(v)));
requestBuilder.setE164SearchRequest(e164RequestBuilder.build());
final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder();
distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished);
lastTreeHeadSize.ifPresent(consistencyBuilder::setLast);
requestBuilder.setConsistency(consistencyBuilder.build());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().search(requestBuilder.build()));
verifyNoInteractions(keyTransparencyServiceClient);
}
private static Stream<Arguments> searchInvalidRequest() {
byte[] aciBytes = ACI.toCompactByteArray();
return Stream.of(
Arguments.argumentSet("Empty ACI", Optional.empty(), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Null ACI identity key", Optional.of(aciBytes), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Invalid ACI", Optional.of(new byte[15]), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Non-positive consistency.last", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)),
Arguments.argumentSet("consistency.distinguished not provided",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()),
Arguments.argumentSet("Non-positive consistency.distinguished",Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L)),
Arguments.argumentSet("E164 can't be provided without an unidentified access key", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.of(NUMBER), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Unidentified access key can't be provided without E164", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.of(UNIDENTIFIED_ACCESS_KEY), Optional.empty(), Optional.empty(), Optional.of(4L)),
Arguments.argumentSet("Invalid username hash", Optional.of(aciBytes), Optional.of(ACI_IDENTITY_KEY), Optional.empty(), Optional.empty(), Optional.of(new byte[19]), Optional.empty(), Optional.of(4L))
);
}
@Test
void searchRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final SearchRequest request = SearchRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setAciIdentityKey(ByteString.copyFrom(ACI_IDENTITY_KEY.serialize()))
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.build())
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().search(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any())).thenReturn(MonitorResponse.getDefaultInstance());
when(rateLimiter.validateReactive(any(String.class)))
.thenReturn(Mono.empty());
final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH]))
.setEntryPosition(10)
.build();
final MonitorRequest request = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.setLast(10)
.build())
.build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().monitor(request));
verify(keyTransparencyServiceClient, times(1)).monitor(eq(request));
}
@ParameterizedTest
@MethodSource
void monitorInvalidRequest(final Optional<AciMonitorRequest> aciMonitorRequest,
final Optional<E164MonitorRequest> e164MonitorRequest,
final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest,
final Optional<Long> lastTreeHeadSize,
final Optional<Long> distinguishedTreeHeadSize) {
final MonitorRequest.Builder requestBuilder = MonitorRequest.newBuilder();
aciMonitorRequest.ifPresent(requestBuilder::setAci);
e164MonitorRequest.ifPresent(requestBuilder::setE164);
usernameHashMonitorRequest.ifPresent(requestBuilder::setUsernameHash);
final ConsistencyParameters.Builder consistencyBuilder = ConsistencyParameters.newBuilder();
lastTreeHeadSize.ifPresent(consistencyBuilder::setLast);
distinguishedTreeHeadSize.ifPresent(consistencyBuilder::setDistinguished);
requestBuilder.setConsistency(consistencyBuilder.build());
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().monitor(requestBuilder.build()));
}
private static Stream<Arguments> monitorInvalidRequest() {
final Optional<AciMonitorRequest> validAciMonitorRequest = Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 10));
return Stream.of(
Arguments.argumentSet("ACI monitor request can't be unset", Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("ACI can't be empty",Optional.of(AciMonitorRequest.newBuilder().build()), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Empty ACI on ACI monitor request",Optional.of(constructAciMonitorRequest(new byte[0], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid ACI", Optional.of(constructAciMonitorRequest(new byte[15], new byte[32], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[31], 10)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on ACI monitor request", Optional.of(constructAciMonitorRequest(ACI.toCompactByteArray(), new byte[32], 0)), Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("E164 can't be blank", validAciMonitorRequest, Optional.of(constructE164MonitorRequest("", new byte[32], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[31], 10)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on E164 monitor request", validAciMonitorRequest, Optional.of(constructE164MonitorRequest(NUMBER, new byte[32], 0)), Optional.empty(), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Username hash can't be empty", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[0], new byte[32], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid username hash length", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(new byte[31], new byte[32], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid commitment index on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[31], 10)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("Invalid entry position on username hash monitor request", validAciMonitorRequest, Optional.empty(), Optional.of(constructUsernameHashMonitorRequest(USERNAME_HASH, new byte[32], 0)), Optional.of(4L), Optional.of(4L)),
Arguments.argumentSet("consistency.last must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(4L),
Arguments.argumentSet("consistency.last must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(0L), Optional.of(4L)),
Arguments.argumentSet("consistency.distinguished must be provided", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L)), Optional.empty()),
Arguments.argumentSet("consistency.distinguished must be positive", validAciMonitorRequest, Optional.empty(), Optional.empty(), Optional.of(4L), Optional.of(0L))
);
}
@Test
void monitorRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(ACI.toCompactByteArray()))
.setCommitmentIndex(ByteString.copyFrom(new byte[COMMITMENT_INDEX_LENGTH]))
.setEntryPosition(10)
.build();
final MonitorRequest request = MonitorRequest.newBuilder()
.setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
.setDistinguished(10)
.setLast(10)
.build())
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().monitor(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void distinguishedSuccess() {
when(keyTransparencyServiceClient.distinguished(any())).thenReturn(DistinguishedResponse.getDefaultInstance());
when(rateLimiter.validateReactive(any(String.class)))
.thenReturn(Mono.empty());
final DistinguishedRequest request = DistinguishedRequest.newBuilder().build();
assertDoesNotThrow(() -> unauthenticatedServiceStub().distinguished(request));
verify(keyTransparencyServiceClient, times(1)).distinguished(eq(request));
}
@Test
void distinguishedInvalidRequest() {
final DistinguishedRequest request = DistinguishedRequest.newBuilder()
.setLast(0)
.build();
assertStatusException(Status.INVALID_ARGUMENT, () -> unauthenticatedServiceStub().distinguished(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
@Test
void distinguishedRatelimited() throws RateLimitExceededException {
final Duration retryAfterDuration = Duration.ofMinutes(7);
Mockito.doThrow(new RateLimitExceededException(retryAfterDuration)).when(rateLimiter).validate(any(String.class));
final DistinguishedRequest request = DistinguishedRequest.newBuilder()
.setLast(10)
.build();
assertRateLimitExceeded(retryAfterDuration, () -> unauthenticatedServiceStub().distinguished(request));
verifyNoInteractions(keyTransparencyServiceClient);
}
private static AciMonitorRequest constructAciMonitorRequest(final byte[] aci, final byte[] commitmentIndex, final long entryPosition) {
return AciMonitorRequest.newBuilder()
.setAci(ByteString.copyFrom(aci))
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
private static E164MonitorRequest constructE164MonitorRequest(final String e164, final byte[] commitmentIndex, final long entryPosition) {
return E164MonitorRequest.newBuilder()
.setE164(e164)
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
private static UsernameHashMonitorRequest constructUsernameHashMonitorRequest(final byte[] usernameHash, final byte[] commitmentIndex, final long entryPosition) {
return UsernameHashMonitorRequest.newBuilder()
.setUsernameHash(ByteString.copyFrom(usernameHash))
.setCommitmentIndex(ByteString.copyFrom(commitmentIndex))
.setEntryPosition(entryPosition)
.build();
}
}