diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java index 6d13ec16c..b983005ae 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java @@ -57,7 +57,8 @@ import java.util.concurrent.CompletionException; public class KeyTransparencyController { private static final Logger LOGGER = LoggerFactory.getLogger(KeyTransparencyController.class); - private static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15); + @VisibleForTesting + static final Duration KEY_TRANSPARENCY_RPC_TIMEOUT = Duration.ofSeconds(15); private static final byte USERNAME_PREFIX = (byte) 'u'; private static final byte E164_PREFIX = (byte) 'n'; @VisibleForTesting @@ -95,12 +96,14 @@ public class KeyTransparencyController { final CompletableFuture aciSearchKeyResponseFuture = keyTransparencyServiceClient.search( getFullSearchKeyByteString(ACI_PREFIX, request.aci().toCompactByteArray()), request.lastTreeHeadSize(), + request.distinguishedTreeHeadSize(), KEY_TRANSPARENCY_RPC_TIMEOUT); final CompletableFuture e164SearchKeyResponseFuture = request.e164() .map(e164 -> keyTransparencyServiceClient.search( getFullSearchKeyByteString(E164_PREFIX, e164.getBytes(StandardCharsets.UTF_8)), request.lastTreeHeadSize(), + request.distinguishedTreeHeadSize(), KEY_TRANSPARENCY_RPC_TIMEOUT)) .orElse(CompletableFuture.completedFuture(null)); @@ -108,6 +111,7 @@ public class KeyTransparencyController { .map(usernameHash -> keyTransparencyServiceClient.search( getFullSearchKeyByteString(USERNAME_PREFIX, request.usernameHash().get()), request.lastTreeHeadSize(), + request.distinguishedTreeHeadSize(), KEY_TRANSPARENCY_RPC_TIMEOUT)) .orElse(CompletableFuture.completedFuture(null)); @@ -169,7 +173,8 @@ public class KeyTransparencyController { final MonitorResponse monitorResponse = keyTransparencyServiceClient.monitor( monitorKeys, - request.lastTreeHeadSize(), + request.lastNonDistinguishedTreeHeadSize(), + request.lastDistinguishedTreeHeadSize(), KEY_TRANSPARENCY_RPC_TIMEOUT).join(); MonitorProof usernameHashMonitorProof = null; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyMonitorRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyMonitorRequest.java index 6dd6e8ca9..a43ba2805 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyMonitorRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencyMonitorRequest.java @@ -45,7 +45,10 @@ public record KeyTransparencyMonitorRequest( Optional> usernameHashPositions, @Schema(description = "The tree head size to prove consistency against.") - Optional<@Positive Long> lastTreeHeadSize + Optional<@Positive Long> lastNonDistinguishedTreeHeadSize, + + @Schema(description = "The distinguished tree head size to prove consistency against.") + Optional<@Positive Long> lastDistinguishedTreeHeadSize ) { @AssertTrue diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencySearchRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencySearchRequest.java index e56af0827..85fd78537 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencySearchRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KeyTransparencySearchRequest.java @@ -33,6 +33,9 @@ public record KeyTransparencySearchRequest( @Schema(description = "The username hash to look up, encoded in web-safe unpadded base64.") Optional usernameHash, - @Schema(description = "The tree head size to prove consistency against.") - Optional<@Positive Long> lastTreeHeadSize + @Schema(description = "The non-distinguished tree head size to prove consistency against.") + Optional<@Positive Long> lastTreeHeadSize, + + @Schema(description = "The distinguished tree head size to prove consistency against.") + Optional<@Positive Long> distinguishedTreeHeadSize ) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java index ea5e9931d..b475bfa75 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/keytransparency/KeyTransparencyServiceClient.java @@ -18,6 +18,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import io.grpc.TlsChannelCredentials; +import katie.ConsistencyParameters; import katie.KatieGrpc; import katie.MonitorKey; import katie.MonitorRequest; @@ -55,11 +56,16 @@ public class KeyTransparencyServiceClient implements Managed { public CompletableFuture search( final ByteString searchKey, final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize, final Duration timeout) { final SearchRequest.Builder searchRequestBuilder = SearchRequest.newBuilder() .setSearchKey(searchKey); - lastTreeHeadSize.ifPresent(searchRequestBuilder::setLast); + final ConsistencyParameters.Builder consistency = ConsistencyParameters.newBuilder(); + lastTreeHeadSize.ifPresent(consistency::setLast); + distinguishedTreeHeadSize.ifPresent(consistency::setDistinguished); + + searchRequestBuilder.setConsistency(consistency); return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) .search(searchRequestBuilder.build()), callbackExecutor); @@ -67,11 +73,16 @@ public class KeyTransparencyServiceClient implements Managed { public CompletableFuture monitor(final List monitorKeys, final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize, final Duration timeout) { final MonitorRequest.Builder monitorRequestBuilder = MonitorRequest.newBuilder() .addAllContactKeys(monitorKeys); - lastTreeHeadSize.ifPresent(monitorRequestBuilder::setLast); + final ConsistencyParameters.Builder consistency = ConsistencyParameters.newBuilder(); + lastTreeHeadSize.ifPresent(consistency::setLast); + distinguishedTreeHeadSize.ifPresent(consistency::setDistinguished); + + monitorRequestBuilder.setConsistency(consistency); return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) .monitor(monitorRequestBuilder.build()), callbackExecutor); diff --git a/service/src/main/proto/KeyTransparencyService.proto b/service/src/main/proto/KeyTransparencyService.proto index bb3122947..e97f3d099 100644 --- a/service/src/main/proto/KeyTransparencyService.proto +++ b/service/src/main/proto/KeyTransparencyService.proto @@ -21,6 +21,27 @@ service Katie { rpc Monitor(MonitorRequest) returns (MonitorResponse) {} } +/** + * The tree head size(s) to prove consistency against. A client's very first + * key transparency request should be looking up the "distinguished" key; + * in this case, both fields will be omitted since the client has no previous + * tree heads to prove consistency against. + */ +message ConsistencyParameters { + /** + * The non-distinguished tree head size to prove consistency against. + * This field may be omitted if the client is looking up a search key + * for the first time. + */ + optional uint64 last = 1; + /** + * The distinguished tree head size to prove consistency against. + * This field may be omitted when the client is looking up the + * "distinguished" key for the very first time. + */ + optional uint64 distinguished = 2; +} + // TODO: add a `value` field so that the KT server can verify that the given search key is mapped // to the provided value. message SearchRequest { @@ -37,9 +58,9 @@ message SearchRequest { */ optional uint32 version = 2; /** - * The tree head size to prove consistency against. + * The tree head size(s) to prove consistency against. */ - optional uint64 last = 3; + ConsistencyParameters consistency = 3; } message SearchResponse { @@ -73,14 +94,18 @@ message FullTreeHead { * A representation of the log tree's current state signed by the key transparency service. */ TreeHead tree_head = 1; + /** + * A consistency proof between the current tree size and the requested distinguished tree size. + */ + repeated bytes distinguished = 2; /** * A consistency proof between the current tree size and the requested tree size. */ - repeated bytes consistency = 2; + repeated bytes consistency = 3; /** * A tree head signed by a third-party auditor. */ - optional AuditorTreeHead auditor_tree_head = 3; + optional AuditorTreeHead auditor_tree_head = 4; } message TreeHead { @@ -200,9 +225,9 @@ message MonitorRequest { */ repeated MonitorKey contact_keys = 2; /** - * The tree head size that the key transparency server must prove consistency against. + * The tree head size(s) to prove consistency against. */ - optional uint64 last = 3; + ConsistencyParameters consistency = 3; } message MonitorProof { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java index d0114682e..d76b36040 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyControllerTest.java @@ -62,7 +62,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @ExtendWith(DropwizardExtensionsSupport.class) @@ -119,15 +122,17 @@ public class KeyTransparencyControllerTest { @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @ParameterizedTest @MethodSource - void searchSuccess(final Optional e164, final Optional usernameHash) { + void searchSuccess(final Optional e164, final Optional usernameHash, final int expectedNumClientCalls) { final SearchResponse searchResponse = SearchResponse.newBuilder().build(); - when(keyTransparencyServiceClient.search(any(), any(), any())) + when(keyTransparencyServiceClient.search(any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(searchResponse)); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") .request(); - try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, e164, usernameHash)))) { + + final String searchJson = createSearchRequestJson(ACI, e164, usernameHash, Optional.of(3L), Optional.of(4L)); + try (Response response = request.post(Entity.json(searchJson))) { assertEquals(200, response.getStatus()); final KeyTransparencySearchResponse keyTransparencySearchResponse = response.readEntity( @@ -140,14 +145,17 @@ public class KeyTransparencyControllerTest { e164.ifPresentOrElse(ignored -> assertTrue(keyTransparencySearchResponse.e164SearchResponse().isPresent()), () -> assertTrue(keyTransparencySearchResponse.e164SearchResponse().isEmpty())); + + verify(keyTransparencyServiceClient, times(expectedNumClientCalls)).search(any(), eq(Optional.of(3L)), eq(Optional.of(4L)), + eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); } } private static Stream searchSuccess() { return Stream.of( - Arguments.of(Optional.empty(), Optional.empty()), - Arguments.of(Optional.empty(), Optional.of(TestRandomUtil.nextBytes(20))), - Arguments.of(Optional.of(NUMBER), Optional.empty()) + Arguments.of(Optional.empty(), Optional.empty(), 1), + Arguments.of(Optional.empty(), Optional.of(TestRandomUtil.nextBytes(20)), 2), + Arguments.of(Optional.of(NUMBER), Optional.empty(), 2) ); } @@ -157,22 +165,24 @@ public class KeyTransparencyControllerTest { .target("/v1/key-transparency/search") .request() .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); - try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty())))) { + try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(400, response.getStatus()); } + verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any()); } @ParameterizedTest @MethodSource void searchGrpcErrors(final Status grpcStatus, final int httpStatus) { - when(keyTransparencyServiceClient.search(any(), any(), any())) + when(keyTransparencyServiceClient.search(any(), any(), any(), any())) .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") .request(); - try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty())))) { + try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(httpStatus, response.getStatus()); + verify(keyTransparencyServiceClient, times(1)).search(any(), any(), any(), any()); } } @@ -185,18 +195,33 @@ public class KeyTransparencyControllerTest { ); } - @Test - void searchInvalidRequest() { + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + @ParameterizedTest + @MethodSource + void searchInvalidRequest(final AciServiceIdentifier aci, + final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize) { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") .request(); try (Response response = request.post(Entity.json( - // ACI can't be null - createSearchRequestJson(null, Optional.empty(), Optional.empty())))) { + createSearchRequestJson(aci, Optional.empty(), Optional.empty(), lastTreeHeadSize, distinguishedTreeHeadSize)))) { assertEquals(422, response.getStatus()); + verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any()); } } + private static Stream searchInvalidRequest() { + return Stream.of( + // ACI can't be null + Arguments.of(null, Optional.empty(), Optional.empty()), + // lastNonDistinguishedTreeHeadSize must be positive + Arguments.of(ACI, Optional.of(0L), Optional.empty()), + // lastDistinguishedTreeHeadSize must be positive + Arguments.of(ACI, Optional.empty(), Optional.of(0L)) + ); + } + @Test void searchRatelimited() { MockUtils.updateRateLimiterResponseToFail( @@ -204,8 +229,9 @@ public class KeyTransparencyControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/key-transparency/search") .request(); - try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty())))) { + try (Response response = request.post(Entity.json(createSearchRequestJson(ACI, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(429, response.getStatus()); + verify(keyTransparencyServiceClient, never()).search(any(), any(), any(), any()); } } @@ -226,7 +252,7 @@ public class KeyTransparencyControllerTest { .addAllContactProofs(monitorProofs) .build(); - when(keyTransparencyServiceClient.monitor(any(), any(), any())) + when(keyTransparencyServiceClient.monitor(any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(monitorResponse)); final Invocation.Builder request = resources.getJerseyTest() @@ -237,7 +263,8 @@ public class KeyTransparencyControllerTest { createMonitorRequestJson( ACI, List.of(3L), usernameHash, usernameHashPositions, - e164, e164Positions)))) { + e164, e164Positions, + Optional.of(3L), Optional.of(4L))))) { assertEquals(200, response.getStatus()); final KeyTransparencyMonitorResponse keyTransparencyMonitorResponse = response.readEntity( @@ -250,6 +277,9 @@ public class KeyTransparencyControllerTest { e164.ifPresentOrElse(ignored -> assertTrue(keyTransparencyMonitorResponse.e164MonitorProof().isPresent()), () -> assertTrue(keyTransparencyMonitorResponse.e164MonitorProof().isEmpty())); + + verify(keyTransparencyServiceClient, times(1)).monitor( + any(), eq(Optional.of(3L)), eq(Optional.of(4L)), eq(KeyTransparencyController.KEY_TRANSPARENCY_RPC_TIMEOUT)); } } @@ -269,15 +299,16 @@ public class KeyTransparencyControllerTest { .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); try (Response response = request.post( Entity.json(createMonitorRequestJson(ACI, List.of(3L), Optional.empty(), Optional.empty(), - Optional.empty(), Optional.empty())))) { + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(400, response.getStatus()); + verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); } } @ParameterizedTest @MethodSource void monitorGrpcErrors(final Status grpcStatus, final int httpStatus) { - when(keyTransparencyServiceClient.monitor(any(), any(), any())) + when(keyTransparencyServiceClient.monitor(any(), any(), any(), any())) .thenReturn(CompletableFuture.failedFuture(new CompletionException(new StatusRuntimeException(grpcStatus)))); final Invocation.Builder request = resources.getJerseyTest() @@ -285,8 +316,9 @@ public class KeyTransparencyControllerTest { .request(); try (Response response = request.post( Entity.json(createMonitorRequestJson(ACI, List.of(3L), Optional.empty(), Optional.empty(), - Optional.empty(), Optional.empty())))) { + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(httpStatus, response.getStatus()); + verify(keyTransparencyServiceClient, times(1)).monitor(any(), any(), any(), any()); } } @@ -307,6 +339,7 @@ public class KeyTransparencyControllerTest { .request(); try (Response response = request.post(Entity.json(requestJson))) { assertEquals(422, response.getStatus()); + verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); } } @@ -314,29 +347,35 @@ public class KeyTransparencyControllerTest { return Stream.of( // aci and aciPositions can't be empty Arguments.of(createMonitorRequestJson(null, null, Optional.empty(), Optional.empty(), Optional.empty(), - Optional.empty())), + Optional.empty(), Optional.empty(), Optional.empty())), // aciPositions list can't be empty Arguments.of(createMonitorRequestJson(ACI, Collections.emptyList(), Optional.empty(), Optional.empty(), Optional.empty(), - Optional.empty())), + Optional.empty(), Optional.empty(), Optional.empty())), // usernameHash cannot be empty if usernameHashPositions isn't Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), Optional.of(List.of(5L)), - Optional.empty(), Optional.empty())), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())), // usernameHashPosition cannot be empty if usernameHash isn't Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.of(TestRandomUtil.nextBytes(20)), - Optional.empty(), Optional.empty(), Optional.empty())), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())), // usernameHashPositions list cannot be empty Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.of(TestRandomUtil.nextBytes(20)), - Optional.of(Collections.emptyList()), Optional.empty(), Optional.empty())), + Optional.of(Collections.emptyList()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())), // e164 cannot be empty if e164Positions isn't Arguments.of( createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(List.of(5L)))), + Optional.of(List.of(5L)), Optional.empty(), Optional.empty())), // e164Positions cannot be empty if e164 isn't Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), - Optional.empty(), Optional.of(NUMBER), Optional.empty())), + Optional.empty(), Optional.of(NUMBER), Optional.empty(), Optional.empty(), Optional.empty())), // e164Positions list cannot empty Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), - Optional.empty(), Optional.of(NUMBER), Optional.of(Collections.emptyList()))) + Optional.empty(), Optional.of(NUMBER), Optional.of(Collections.emptyList()), Optional.empty(), Optional.empty())), + // lastNonDistinguishedTreeHeadSize must be positive + Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(0L), Optional.empty())), + // lastDistinguishedTreeHeadSize must be positive + Arguments.of(createMonitorRequestJson(ACI, List.of(4L), Optional.empty(), + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.of(-1L))) ); } @@ -349,8 +388,9 @@ public class KeyTransparencyControllerTest { .request(); try (Response response = request.post( Entity.json(createMonitorRequestJson(ACI, List.of(3L), Optional.empty(), Optional.empty(), - Optional.empty(), Optional.empty())))) { + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())))) { assertEquals(429, response.getStatus()); + verify(keyTransparencyServiceClient, never()).monitor(any(), any(), any(), any()); } } @@ -365,9 +405,11 @@ public class KeyTransparencyControllerTest { final Optional usernameHash, final Optional> usernameHashPositions, final Optional e164, - final Optional> e164Positions) { + final Optional> e164Positions, + final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize) { final KeyTransparencyMonitorRequest request = new KeyTransparencyMonitorRequest(aci, aciPositions, - e164, e164Positions, usernameHash, usernameHashPositions, Optional.empty()); + e164, e164Positions, usernameHash, usernameHashPositions, lastTreeHeadSize, distinguishedTreeHeadSize); try { return SystemMapper.jsonMapper().writeValueAsString(request); } catch (final JsonProcessingException e) { @@ -379,8 +421,10 @@ public class KeyTransparencyControllerTest { private static String createSearchRequestJson( final AciServiceIdentifier aci, final Optional e164, - final Optional usernameHash) { - final KeyTransparencySearchRequest request = new KeyTransparencySearchRequest(aci, e164, usernameHash, null); + final Optional usernameHash, + final Optional lastTreeHeadSize, + final Optional distinguishedTreeHeadSize) { + final KeyTransparencySearchRequest request = new KeyTransparencySearchRequest(aci, e164, usernameHash, lastTreeHeadSize, distinguishedTreeHeadSize); try { return SystemMapper.jsonMapper().writeValueAsString(request); } catch (final JsonProcessingException e) {