diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 7907a4379..961a6803e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -16,6 +16,9 @@ import io.swagger.v3.oas.annotations.headers.Header; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -29,6 +32,7 @@ import javax.ws.rs.Consumes; import javax.ws.rs.DefaultValue; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; +import javax.ws.rs.POST; import javax.ws.rs.PUT; import javax.ws.rs.Path; import javax.ws.rs.PathParam; @@ -42,6 +46,7 @@ import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; +import org.whispersystems.textsecuregcm.entities.CheckKeysRequest; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; @@ -221,6 +226,97 @@ public class KeysController { } } + @POST + @Path("/check") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + @Operation(summary = "Check keys", description = """ + Checks that client and server have consistent views of repeated-use keys. For a given identity type, clients + submit a digest of their repeated-use key material. The digest is calculated as: + + SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes) + + …where the elements of the hash are: + + - identityKeyBytes: the serialized form of the client's public identity key as produced by libsignal (i.e. one + version byte followed by 32 bytes of key material for a total of 33 bytes) + - signedEcPreKeyId: an 8-byte, big-endian representation of the ID of the client's signed EC pre-key + - signedEcPreKeyBytes: the serialized form of the client's signed EC pre-key as produced by libsignal (i.e. one + version byte followed by 32 bytes of key material for a total of 33 bytes) + - lastResortKeyId: an 8-byte, big-endian representation of the ID of the client's last-resort Kyber key + - lastResortKeyBytes: the serialized form of the client's last-resort Kyber key as produced by libsignal (i.e. one + version byte followed by 1568 bytes of key material for a total of 1569 bytes) + """) + @ApiResponse(responseCode = "200", description = "Indicates that client and server have consistent views of repeated-use keys") + @ApiResponse(responseCode = "401", description = "Account authentication check failed") + @ApiResponse(responseCode = "409", description = """ + Indicates that client and server have inconsistent views of repeated-use keys or one or more repeated-use keys could + not be found + """) + @ApiResponse(responseCode = "422", description = "Invalid request format") + public CompletableFuture setKeys( + @ReadOnly @Auth final AuthenticatedAccount auth, + @RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest, + @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) { + + final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType()); + final byte deviceId = auth.getAuthenticatedDevice().getId(); + + final CompletableFuture> ecSignedPreKeyFuture = + keysManager.getEcSignedPreKey(identifier, deviceId); + + final CompletableFuture> lastResortKeyFuture = + keysManager.getLastResort(identifier, deviceId); + + return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture) + .thenApply(ignored -> { + final Optional maybeSignedPreKey = ecSignedPreKeyFuture.join(); + final Optional maybeLastResortKey = lastResortKeyFuture.join(); + + final boolean digestsMatch; + + if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) { + final IdentityKey identityKey = auth.getAccount().getIdentityKey(checkKeysRequest.identityType()); + final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get(); + final KEMSignedPreKey lastResortKey = maybeLastResortKey.get(); + + final MessageDigest messageDigest; + + try { + messageDigest = MessageDigest.getInstance("SHA-256"); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e); + } + + messageDigest.update(identityKey.serialize()); + + { + final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId()); + ecSignedPreKeyIdBuffer.flip(); + + messageDigest.update(ecSignedPreKeyIdBuffer); + messageDigest.update(ecSignedPreKey.serializedPublicKey()); + } + + { + final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + lastResortKeyIdBuffer.putLong(lastResortKey.keyId()); + lastResortKeyIdBuffer.flip(); + + messageDigest.update(lastResortKeyIdBuffer); + messageDigest.update(lastResortKey.serializedPublicKey()); + } + + digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest()); + } else { + digestsMatch = false; + } + + return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build(); + }); + } + @GET @Path("/{identifier}/{device_id}") @Produces(MediaType.APPLICATION_JSON) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/CheckKeysRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/CheckKeysRequest.java new file mode 100644 index 000000000..dc58cf7fd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/CheckKeysRequest.java @@ -0,0 +1,34 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.swagger.v3.oas.annotations.media.Schema; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.whispersystems.textsecuregcm.util.ExactlySize; + +public record CheckKeysRequest( + @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ + The identity type for which to check for a shared view of repeated-use keys + """) + IdentityType identityType, + + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + @ExactlySize(32) + @Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """ + A 32-byte digest of the client's repeated-use keys for the given identity type. The digest is calculated as: + + SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes) + + …where the elements of the hash are: + + - identityKeyBytes: the serialized form of the client's public identity key as produced by libsignal (i.e. one + version byte followed by 32 bytes of key material for a total of 33 bytes) + - signedEcPreKeyId: an 8-byte, big-endian representation of the ID of the client's signed EC pre-key + - signedEcPreKeyBytes: the serialized form of the client's signed EC pre-key as produced by libsignal (i.e. one + version byte followed by 32 bytes of key material for a total of 33 bytes) + - lastResortKeyId: an 8-byte, big-endian representation of the ID of the client's last-resort Kyber key + - lastResortKeyBytes: the serialized form of the client's last-resort Kyber key as produced by libsignal (i.e. + one version byte followed by 1568 bytes of key material for a total of 1569 bytes) + """) + byte[] digest) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 077dabe28..a6eed2238 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -106,8 +106,7 @@ public class KeysManager { .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))); } - @VisibleForTesting - CompletableFuture> getLastResort(final UUID identifier, final byte deviceId) { + public CompletableFuture> getLastResort(final UUID identifier, final byte deviceId) { return pqLastResortKeys.find(identifier, deviceId); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index fc6830361..a7349455e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -25,6 +25,9 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.ByteBuffer; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.util.Collections; import java.util.List; @@ -42,13 +45,15 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; -import org.junit.jupiter.params.provider.ValueSource; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.CheckKeysRequest; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; @@ -975,4 +980,162 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(422); } + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + @ParameterizedTest + @MethodSource + void checkKeys( + final IdentityKey clientIdentityKey, + final ECSignedPreKey clientEcSignedPreKey, + final Optional serverEcSignedPreKey, + final KEMSignedPreKey clientLastResortKey, + final Optional serverLastResortKey, + final int expectedStatus) throws NoSuchAlgorithmException { + + when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(serverEcSignedPreKey)); + + when(KEYS.getLastResort(AuthHelper.VALID_UUID, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(serverLastResortKey)); + + final CheckKeysRequest checkKeysRequest = + new CheckKeysRequest(IdentityType.ACI, getKeyDigest(clientIdentityKey, clientEcSignedPreKey, clientLastResortKey)); + + try (final Response response = + resources.getJerseyTest() + .target("/v2/keys/check") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .post(Entity.entity(checkKeysRequest, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(expectedStatus, response.getStatus()); + } + } + + private static List checkKeys() { + final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(17, AuthHelper.VALID_IDENTITY_KEY_PAIR); + final KEMSignedPreKey lastResortKey = KeysHelper.signedKEMPreKey(19, AuthHelper.VALID_IDENTITY_KEY_PAIR); + + return List.of( + // All keys match + Arguments.of( + AuthHelper.VALID_IDENTITY, + ecSignedPreKey, + Optional.of(ecSignedPreKey), + lastResortKey, + Optional.of(lastResortKey), + 200), + + // Signed EC pre-key not found + Arguments.of( + AuthHelper.VALID_IDENTITY, + ecSignedPreKey, + Optional.empty(), + lastResortKey, + Optional.of(lastResortKey), + 409), + + // Last-resort key not found + Arguments.of( + AuthHelper.VALID_IDENTITY, + ecSignedPreKey, + Optional.of(ecSignedPreKey), + lastResortKey, + Optional.empty(), + 409), + + // Mismatched identity key + Arguments.of( + new IdentityKey(Curve.generateKeyPair().getPublicKey()), + ecSignedPreKey, + Optional.of(ecSignedPreKey), + lastResortKey, + Optional.of(lastResortKey), + 409), + + // Mismatched EC signed pre-key ID + Arguments.of( + AuthHelper.VALID_IDENTITY, + new ECSignedPreKey(ecSignedPreKey.keyId() + 1, ecSignedPreKey.publicKey(), ecSignedPreKey.signature()), + Optional.of(ecSignedPreKey), + lastResortKey, + Optional.of(lastResortKey), + 409), + + // Mismatched EC signed pre-key content + Arguments.of( + AuthHelper.VALID_IDENTITY, + KeysHelper.signedECPreKey(ecSignedPreKey.keyId(), AuthHelper.VALID_IDENTITY_KEY_PAIR), + Optional.of(ecSignedPreKey), + lastResortKey, + Optional.of(lastResortKey), + 409), + // Mismatched last-resort key ID + Arguments.of( + AuthHelper.VALID_IDENTITY, + ecSignedPreKey, + Optional.of(ecSignedPreKey), + new KEMSignedPreKey(lastResortKey.keyId() + 1, lastResortKey.publicKey(), lastResortKey.signature()), + Optional.of(lastResortKey), + 409), + + // Mismatched last-resort key content + Arguments.of( + AuthHelper.VALID_IDENTITY, + ecSignedPreKey, + Optional.of(ecSignedPreKey), + KeysHelper.signedKEMPreKey(lastResortKey.keyId(), AuthHelper.VALID_IDENTITY_KEY_PAIR), + Optional.of(lastResortKey), + 409) + ); + } + + private static byte[] getKeyDigest(final IdentityKey identityKey, final ECSignedPreKey ecSignedPreKey, final KEMSignedPreKey lastResortKey) + throws NoSuchAlgorithmException { + + final MessageDigest messageDigest = MessageDigest.getInstance("SHA-256"); + messageDigest.update(identityKey.serialize()); + + { + final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId()); + ecSignedPreKeyIdBuffer.flip(); + + messageDigest.update(ecSignedPreKeyIdBuffer); + messageDigest.update(ecSignedPreKey.serializedPublicKey()); + } + + { + final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + lastResortKeyIdBuffer.putLong(lastResortKey.keyId()); + lastResortKeyIdBuffer.flip(); + + messageDigest.update(lastResortKeyIdBuffer); + messageDigest.update(lastResortKey.serializedPublicKey()); + } + + return messageDigest.digest(); + } + + @Test + void checkKeysIncorrectDigestLength() { + try (final Response response = + resources.getJerseyTest() + .target("/v2/keys/check") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .post(Entity.entity(new CheckKeysRequest(IdentityType.ACI, new byte[31]), MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + + try (final Response response = + resources.getJerseyTest() + .target("/v2/keys/check") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .post(Entity.entity(new CheckKeysRequest(IdentityType.ACI, new byte[33]), MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + } + } }