Update for KeyTransparencyQueryService.MonitorRequest changes

This commit is contained in:
Chris Eager 2024-11-05 18:07:56 -06:00 committed by Jon Chambers
parent 96a4d4c8ac
commit b182c3d86d
4 changed files with 78 additions and 111 deletions

View File

@ -15,10 +15,7 @@ import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -37,8 +34,10 @@ import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.MonitorKey; import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -60,12 +59,6 @@ public class KeyTransparencyController {
private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class); private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class);
@VisibleForTesting @VisibleForTesting
static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15); static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15);
@VisibleForTesting
static final byte USERNAME_PREFIX = (byte) 'u';
@VisibleForTesting
static final byte E164_PREFIX = (byte) 'n';
@VisibleForTesting
static final byte ACI_PREFIX = (byte) 'a';
private final KeyTransparencyServiceClient keyTransparencyServiceClient; private final KeyTransparencyServiceClient keyTransparencyServiceClient;
public KeyTransparencyController( public KeyTransparencyController(
@ -139,9 +132,10 @@ public class KeyTransparencyController {
identifiers. Enforced unauthenticated endpoint. identifiers. Enforced unauthenticated endpoint.
""" """
) )
@ApiResponse(responseCode = "200", description = "All search keys exist in the log", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "All identifiers exist in the log", useReturnTypeSchema = true)
@ApiResponse(responseCode = "400", description = "Invalid request. See response for any available details.") @ApiResponse(responseCode = "400", description = "Invalid request. See response for any available details.")
@ApiResponse(responseCode = "404", description = "At least one search key lookup did not find the key") @ApiResponse(responseCode = "403", description = "One or more of the provided commitment indexes did not match")
@ApiResponse(responseCode = "404", description = "At least one identifier was not found")
@ApiResponse(responseCode = "429", description = "Rate-limited") @ApiResponse(responseCode = "429", description = "Rate-limited")
@ApiResponse(responseCode = "422", description = "Invalid request format") @ApiResponse(responseCode = "422", description = "Invalid request format")
@POST @POST
@ -156,26 +150,34 @@ public class KeyTransparencyController {
requireNotAuthenticated(authenticatedAccount); requireNotAuthenticated(authenticatedAccount);
try { try {
final List<MonitorKey> monitorKeys = new ArrayList<>(List.of( final AciMonitorRequest aciMonitorRequest = AciMonitorRequest.newBuilder()
createMonitorKey(getFullSearchKeyByteString(ACI_PREFIX, request.aci().value().toCompactByteArray()), .setAci(ByteString.copyFrom(request.aci().value().toCompactByteArray()))
request.aci().positions(), .addAllEntries(request.aci().positions())
ByteString.copyFrom(request.aci().commitmentIndex())) .setCommitmentIndex(ByteString.copyFrom(request.aci().commitmentIndex()))
)); .build();
request.usernameHash().ifPresent(usernameHash -> final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest = request.usernameHash().map(usernameHash ->
monitorKeys.add(createMonitorKey(getFullSearchKeyByteString(USERNAME_PREFIX, usernameHash.value()), UsernameHashMonitorRequest.newBuilder()
usernameHash.positions(), ByteString.copyFrom(usernameHash.commitmentIndex())))); .setUsernameHash(ByteString.copyFrom(usernameHash.value()))
.addAllEntries(usernameHash.positions())
.setCommitmentIndex(ByteString.copyFrom(usernameHash.commitmentIndex()))
.build());
request.e164().ifPresent(e164 -> final Optional<E164MonitorRequest> e164MonitorRequest = request.e164().map(e164 ->
monitorKeys.add( E164MonitorRequest.newBuilder()
createMonitorKey(getFullSearchKeyByteString(E164_PREFIX, e164.value().getBytes(StandardCharsets.UTF_8)), .setE164(e164.value())
e164.positions(), ByteString.copyFrom(e164.commitmentIndex())))); .addAllEntries(e164.positions())
.setCommitmentIndex(ByteString.copyFrom(e164.commitmentIndex()))
.build());
return new KeyTransparencyMonitorResponse(keyTransparencyServiceClient.monitor( return new KeyTransparencyMonitorResponse(keyTransparencyServiceClient.monitor(
monitorKeys, aciMonitorRequest,
usernameHashMonitorRequest,
e164MonitorRequest,
request.lastNonDistinguishedTreeHeadSize(), request.lastNonDistinguishedTreeHeadSize(),
request.lastDistinguishedTreeHeadSize(), request.lastDistinguishedTreeHeadSize(),
KEY_TRANSPARENCY_RPC_TIMEOUT).join()); KEY_TRANSPARENCY_RPC_TIMEOUT).join());
} catch (final CancellationException exception) { } catch (final CancellationException exception) {
LOGGER.error("Unexpected cancellation from key transparency service", exception); LOGGER.error("Unexpected cancellation from key transparency service", exception);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception); throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE, exception);
@ -242,28 +244,10 @@ public class KeyTransparencyController {
throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped); throw new ServerErrorException(Response.Status.INTERNAL_SERVER_ERROR, unwrapped);
} }
private static MonitorKey createMonitorKey(final ByteString fullSearchKey, final List<Long> positions,
final ByteString commitmentIndex) {
return MonitorKey.newBuilder()
.setSearchKey(fullSearchKey)
.addAllEntries(positions)
.setCommitmentIndex(commitmentIndex)
.build();
}
private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) { private void requireNotAuthenticated(final Optional<AuthenticatedDevice> authenticatedAccount) {
if (authenticatedAccount.isPresent()) { if (authenticatedAccount.isPresent()) {
throw new BadRequestException("Endpoint requires unauthenticated access"); throw new BadRequestException("Endpoint requires unauthenticated access");
} }
} }
@VisibleForTesting
static ByteString getFullSearchKeyByteString(final byte prefix, final byte[] searchKeyBytes) {
final ByteBuffer fullSearchKeyBuffer = ByteBuffer.allocate(searchKeyBytes.length + 1);
fullSearchKeyBuffer.put(prefix);
fullSearchKeyBuffer.put(searchKeyBytes);
fullSearchKeyBuffer.flip();
return ByteString.copyFrom(fullSearchKeyBuffer.array());
}
} }

View File

@ -19,18 +19,19 @@ import java.security.cert.X509Certificate;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Collection; import java.util.Collection;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.signal.keytransparency.client.AciMonitorRequest;
import org.signal.keytransparency.client.ConsistencyParameters; import org.signal.keytransparency.client.ConsistencyParameters;
import org.signal.keytransparency.client.DistinguishedRequest; import org.signal.keytransparency.client.DistinguishedRequest;
import org.signal.keytransparency.client.E164MonitorRequest;
import org.signal.keytransparency.client.E164SearchRequest; import org.signal.keytransparency.client.E164SearchRequest;
import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc; import org.signal.keytransparency.client.KeyTransparencyQueryServiceGrpc;
import org.signal.keytransparency.client.MonitorKey;
import org.signal.keytransparency.client.MonitorRequest; import org.signal.keytransparency.client.MonitorRequest;
import org.signal.keytransparency.client.SearchRequest; import org.signal.keytransparency.client.SearchRequest;
import org.signal.keytransparency.client.UsernameHashMonitorRequest;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@ -138,19 +139,22 @@ public class KeyTransparencyServiceClient implements Managed {
.thenApply(AbstractMessageLite::toByteArray); .thenApply(AbstractMessageLite::toByteArray);
} }
public CompletableFuture<byte[]> monitor(final List<MonitorKey> monitorKeys, @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public CompletableFuture<byte[]> monitor(final AciMonitorRequest aciMonitorRequest,
final Optional<UsernameHashMonitorRequest> usernameHashMonitorRequest,
final Optional<E164MonitorRequest> e164MonitorRequest,
final long lastTreeHeadSize, final long lastTreeHeadSize,
final long distinguishedTreeHeadSize, final long distinguishedTreeHeadSize,
final Duration timeout) { final Duration timeout) {
final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder() final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder()
.addAllContactKeys(monitorKeys); .setAci(aciMonitorRequest)
.setConsistency(ConsistencyParameters.newBuilder()
.setLast(lastTreeHeadSize)
.setDistinguished(distinguishedTreeHeadSize)
.build());
final ConsistencyParameters consistency = ConsistencyParameters.newBuilder() usernameHashMonitorRequest.ifPresent(monitorRequestBuilder::setUsernameHash);
.setLast(lastTreeHeadSize) e164MonitorRequest.ifPresent(monitorRequestBuilder::setE164);
.setDistinguished(distinguishedTreeHeadSize)
.build();
monitorRequestBuilder.setConsistency(consistency);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.monitor(monitorRequestBuilder.build()), callbackExecutor) .monitor(monitorRequestBuilder.build()), callbackExecutor)

View File

@ -34,7 +34,7 @@ service KeyTransparencyQueryService {
*/ */
rpc Search(SearchRequest) returns (SearchResponse) {} rpc Search(SearchRequest) returns (SearchResponse) {}
/** /**
* An endpoint that allows users to monitor a set of identifiers by returning proof that the log continues to be * An endpoint that allows users to monitor a group of identifiers by returning proof that the log continues to be
* constructed correctly in later entries for those identifiers. * constructed correctly in later entries for those identifiers.
*/ */
rpc Monitor(MonitorRequest) returns (MonitorResponse) {} rpc Monitor(MonitorRequest) returns (MonitorResponse) {}
@ -141,7 +141,7 @@ message DistinguishedRequest {
/** /**
* DistinguishedResponse contains the tree head and search proof for the most * DistinguishedResponse contains the tree head and search proof for the most
* recent `distinguished` key in the log. * recent `distinguished` key in the log.
*/ */
message DistinguishedResponse { message DistinguishedResponse {
/** /**
@ -286,46 +286,34 @@ message PrefixSearchResult {
uint32 counter = 2; uint32 counter = 2;
} }
message MonitorKey { message MonitorRequest {
/** AciMonitorRequest aci = 1;
* The key to search for in the log tree. optional UsernameHashMonitorRequest username_hash = 2;
*/ optional E164MonitorRequest e164 = 3;
bytes search_key = 1; ConsistencyParameters consistency = 4;
/** }
* A list of log tree positions maintained by a client for the identifier being monitored.
* Each position is in the direct path to a key version and corresponds to a tree head message AciMonitorRequest {
* that has been verified to contain that version or greater. bytes aci = 1;
* The key transparency server uses this list to compute which log entries to return
* in the corresponding MonitorProof.
*/
repeated uint64 entries = 2; repeated uint64 entries = 2;
/**
* The commitment index for the identifier. This is derived from vrf_proof in
* the SearchResponse.
*/
bytes commitment_index = 3; bytes commitment_index = 3;
} }
message UsernameHashMonitorRequest {
bytes username_hash = 1;
repeated uint64 entries = 2;
bytes commitment_index = 3;
}
message MonitorRequest { message E164MonitorRequest {
/** string e164 = 1;
* TODO: Remove this protobuf field in the KT server repeated uint64 entries = 2;
*/ bytes commitment_index = 3;
repeated MonitorKey owned_keys = 1;
/**
* The list of identifiers that the client would like to monitor.
* All identifiers *must* belong to the same user.
*/
repeated MonitorKey contact_keys = 2;
/**
* The tree head size(s) to prove consistency against.
*/
ConsistencyParameters consistency = 3;
} }
message MonitorProof { message MonitorProof {
/** /**
* Generated based on the monitored entries provided in MonitorKey.entries. Each ProofStep * Generated based on the monitored entry provided in MonitorRequest.entries. Each ProofStep
* corresponds to a log tree entry that exists in the search path to each monitored entry * corresponds to a log tree entry that exists in the search path to each monitored entry
* and that came *after* that monitored entry. It proves that the log tree has been constructed * and that came *after* that monitored entry. It proves that the log tree has been constructed
* correctly at that later entry. This list also includes any remaining entries * correctly at that later entry. This list also includes any remaining entries
@ -342,17 +330,22 @@ message MonitorResponse {
*/ */
FullTreeHead tree_head = 1; FullTreeHead tree_head = 1;
/** /**
* TODO: Remove this protobuf field in the KT server * A proof that the MonitorRequest's ACI continues to be constructed correctly in later entries of the log tree.
*/ */
repeated MonitorProof owned_proofs = 2; MonitorProof aci = 2;
/** /**
* A list of proofs, one for each identifier in MonitorRequest.contact_keys, each proving that the given identifier * A proof that the username hash continues to be constructed correctly in later entries of the log tree.
* continues to be constructed correctly in later entries of the log tree. * Will be absent if the request did not include a UsernameHashMonitorRequest.
*/ */
repeated MonitorProof contact_proofs = 3; optional MonitorProof username_hash = 3;
/**
* A proof that the e164 continues to be constructed correctly in later entries of the log tree.
* Will be absent if the request did not include a E164MonitorRequest.
*/
optional MonitorProof e164 = 4;
/** /**
* A batch inclusion proof that the log entries involved in the binary search for each of the entries * A batch inclusion proof that the log entries involved in the binary search for each of the entries
* being monitored in MonitorKey.entries are included in the current log tree. * being monitored in the request are included in the current log tree.
*/ */
repeated bytes inclusion = 4; repeated bytes inclusion = 5;
} }

View File

@ -18,8 +18,6 @@ import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.ACI_PREFIX;
import static org.whispersystems.textsecuregcm.controllers.KeyTransparencyController.getFullSearchKeyByteString;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
@ -124,18 +122,6 @@ public class KeyTransparencyControllerTest {
monitorRatelimiter); monitorRatelimiter);
} }
@Test
void getFullSearchKey() {
final byte[] charBytes = new byte[]{ACI_PREFIX};
final byte[] aci = ACI.toCompactByteArray();
final byte[] expectedFullSearchKey = new byte[aci.length + 1];
System.arraycopy(charBytes, 0, expectedFullSearchKey, 0, charBytes.length);
System.arraycopy(aci, 0, expectedFullSearchKey, charBytes.length, aci.length);
assertArrayEquals(expectedFullSearchKey, getFullSearchKeyByteString(ACI_PREFIX, aci).toByteArray());
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
@ -309,7 +295,7 @@ public class KeyTransparencyControllerTest {
@Test @Test
void monitorSuccess() { void monitorSuccess() {
when(keyTransparencyServiceClient.monitor(any(), anyLong(), anyLong(), any())) when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16))); .thenReturn(CompletableFuture.completedFuture(TestRandomUtil.nextBytes(16)));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -328,7 +314,7 @@ public class KeyTransparencyControllerTest {
assertNotNull(keyTransparencyMonitorResponse.serializedResponse()); assertNotNull(keyTransparencyMonitorResponse.serializedResponse());
verify(keyTransparencyServiceClient, times(1)).monitor( verify(keyTransparencyServiceClient, times(1)).monitor(
any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); any(), any(), any(), eq(3L), eq(4L), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT));
} }
} }
@ -351,7 +337,7 @@ public class KeyTransparencyControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) { void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) {
when(keyTransparencyServiceClient.monitor(any(), anyLong(), anyLong(), any())) when(keyTransparencyServiceClient.monitor(any(), any(), any(), anyLong(), anyLong(), any()))
.thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus))));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -363,7 +349,7 @@ public class KeyTransparencyControllerTest {
new KeyTransparencyMonitorRequest.AciMonitor(ACI, List.of(3L), COMMITMENT_INDEX), new KeyTransparencyMonitorRequest.AciMonitor(ACI, List.of(3L), COMMITMENT_INDEX),
Optional.empty(), Optional.empty(), 3L, 4L))))) { Optional.empty(), Optional.empty(), 3L, 4L))))) {
assertEquals(httpStatus, response.getStatus()); assertEquals(httpStatus, response.getStatus());
verify(keyTransparencyServiceClient, times(1)).monitor(any(), anyLong(), anyLong(), any()); verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), anyLong(), anyLong(), any());
} }
} }