Add an endpoint for checking that clients and the server have a common view of the client's repeated-use keys

This commit is contained in:
Jon Chambers 2024-02-22 20:05:41 -05:00 committed by Jon Chambers
parent 279f877bf2
commit d2716fe5cf
4 changed files with 295 additions and 3 deletions

View File

@ -16,6 +16,9 @@ import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@ -29,6 +32,7 @@ import javax.ws.rs.Consumes;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
@ -42,6 +46,7 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.CheckKeysRequest;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@ -221,6 +226,97 @@ public class KeysController {
}
}
@POST
@Path("/check")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Check keys", description = """
Checks that client and server have consistent views of repeated-use keys. For a given identity type, clients
submit a digest of their repeated-use key material. The digest is calculated as:
SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)
where the elements of the hash are:
- identityKeyBytes: the serialized form of the client's public identity key as produced by libsignal (i.e. one
version byte followed by 32 bytes of key material for a total of 33 bytes)
- signedEcPreKeyId: an 8-byte, big-endian representation of the ID of the client's signed EC pre-key
- signedEcPreKeyBytes: the serialized form of the client's signed EC pre-key as produced by libsignal (i.e. one
version byte followed by 32 bytes of key material for a total of 33 bytes)
- lastResortKeyId: an 8-byte, big-endian representation of the ID of the client's last-resort Kyber key
- lastResortKeyBytes: the serialized form of the client's last-resort Kyber key as produced by libsignal (i.e. one
version byte followed by 1568 bytes of key material for a total of 1569 bytes)
""")
@ApiResponse(responseCode = "200", description = "Indicates that client and server have consistent views of repeated-use keys")
@ApiResponse(responseCode = "401", description = "Account authentication check failed")
@ApiResponse(responseCode = "409", description = """
Indicates that client and server have inconsistent views of repeated-use keys or one or more repeated-use keys could
not be found
""")
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Response> setKeys(
@ReadOnly @Auth final AuthenticatedAccount auth,
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getAuthenticatedDevice().getId();
final CompletableFuture<Optional<ECSignedPreKey>> ecSignedPreKeyFuture =
keysManager.getEcSignedPreKey(identifier, deviceId);
final CompletableFuture<Optional<KEMSignedPreKey>> lastResortKeyFuture =
keysManager.getLastResort(identifier, deviceId);
return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture)
.thenApply(ignored -> {
final Optional<ECSignedPreKey> maybeSignedPreKey = ecSignedPreKeyFuture.join();
final Optional<KEMSignedPreKey> maybeLastResortKey = lastResortKeyFuture.join();
final boolean digestsMatch;
if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) {
final IdentityKey identityKey = auth.getAccount().getIdentityKey(checkKeysRequest.identityType());
final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get();
final KEMSignedPreKey lastResortKey = maybeLastResortKey.get();
final MessageDigest messageDigest;
try {
messageDigest = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e);
}
messageDigest.update(identityKey.serialize());
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest());
} else {
digestsMatch = false;
}
return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build();
});
}
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)

View File

@ -0,0 +1,34 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ExactlySize;
public record CheckKeysRequest(
@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
The identity type for which to check for a shared view of repeated-use keys
""")
IdentityType identityType,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
@ExactlySize(32)
@Schema(requiredMode = Schema.RequiredMode.REQUIRED, description = """
A 32-byte digest of the client's repeated-use keys for the given identity type. The digest is calculated as:
SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)
where the elements of the hash are:
- identityKeyBytes: the serialized form of the client's public identity key as produced by libsignal (i.e. one
version byte followed by 32 bytes of key material for a total of 33 bytes)
- signedEcPreKeyId: an 8-byte, big-endian representation of the ID of the client's signed EC pre-key
- signedEcPreKeyBytes: the serialized form of the client's signed EC pre-key as produced by libsignal (i.e. one
version byte followed by 32 bytes of key material for a total of 33 bytes)
- lastResortKeyId: an 8-byte, big-endian representation of the ID of the client's last-resort Kyber key
- lastResortKeyBytes: the serialized form of the client's last-resort Kyber key as produced by libsignal (i.e.
one version byte followed by 1568 bytes of key material for a total of 1569 bytes)
""")
byte[] digest) {
}

View File

@ -106,8 +106,7 @@ public class KeysManager {
.orElseGet(() -> pqLastResortKeys.find(identifier, deviceId)));
}
@VisibleForTesting
CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
public CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
return pqLastResortKeys.find(identifier, deviceId);
}

View File

@ -25,6 +25,9 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
@ -42,13 +45,15 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.CheckKeysRequest;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
@ -975,4 +980,162 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(422);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource
void checkKeys(
final IdentityKey clientIdentityKey,
final ECSignedPreKey clientEcSignedPreKey,
final Optional<ECSignedPreKey> serverEcSignedPreKey,
final KEMSignedPreKey clientLastResortKey,
final Optional<KEMSignedPreKey> serverLastResortKey,
final int expectedStatus) throws NoSuchAlgorithmException {
when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, Device.PRIMARY_ID))
.thenReturn(CompletableFuture.completedFuture(serverEcSignedPreKey));
when(KEYS.getLastResort(AuthHelper.VALID_UUID, Device.PRIMARY_ID))
.thenReturn(CompletableFuture.completedFuture(serverLastResortKey));
final CheckKeysRequest checkKeysRequest =
new CheckKeysRequest(IdentityType.ACI, getKeyDigest(clientIdentityKey, clientEcSignedPreKey, clientLastResortKey));
try (final Response response =
resources.getJerseyTest()
.target("/v2/keys/check")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.entity(checkKeysRequest, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(expectedStatus, response.getStatus());
}
}
private static List<Arguments> checkKeys() {
final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(17, AuthHelper.VALID_IDENTITY_KEY_PAIR);
final KEMSignedPreKey lastResortKey = KeysHelper.signedKEMPreKey(19, AuthHelper.VALID_IDENTITY_KEY_PAIR);
return List.of(
// All keys match
Arguments.of(
AuthHelper.VALID_IDENTITY,
ecSignedPreKey,
Optional.of(ecSignedPreKey),
lastResortKey,
Optional.of(lastResortKey),
200),
// Signed EC pre-key not found
Arguments.of(
AuthHelper.VALID_IDENTITY,
ecSignedPreKey,
Optional.empty(),
lastResortKey,
Optional.of(lastResortKey),
409),
// Last-resort key not found
Arguments.of(
AuthHelper.VALID_IDENTITY,
ecSignedPreKey,
Optional.of(ecSignedPreKey),
lastResortKey,
Optional.empty(),
409),
// Mismatched identity key
Arguments.of(
new IdentityKey(Curve.generateKeyPair().getPublicKey()),
ecSignedPreKey,
Optional.of(ecSignedPreKey),
lastResortKey,
Optional.of(lastResortKey),
409),
// Mismatched EC signed pre-key ID
Arguments.of(
AuthHelper.VALID_IDENTITY,
new ECSignedPreKey(ecSignedPreKey.keyId() + 1, ecSignedPreKey.publicKey(), ecSignedPreKey.signature()),
Optional.of(ecSignedPreKey),
lastResortKey,
Optional.of(lastResortKey),
409),
// Mismatched EC signed pre-key content
Arguments.of(
AuthHelper.VALID_IDENTITY,
KeysHelper.signedECPreKey(ecSignedPreKey.keyId(), AuthHelper.VALID_IDENTITY_KEY_PAIR),
Optional.of(ecSignedPreKey),
lastResortKey,
Optional.of(lastResortKey),
409),
// Mismatched last-resort key ID
Arguments.of(
AuthHelper.VALID_IDENTITY,
ecSignedPreKey,
Optional.of(ecSignedPreKey),
new KEMSignedPreKey(lastResortKey.keyId() + 1, lastResortKey.publicKey(), lastResortKey.signature()),
Optional.of(lastResortKey),
409),
// Mismatched last-resort key content
Arguments.of(
AuthHelper.VALID_IDENTITY,
ecSignedPreKey,
Optional.of(ecSignedPreKey),
KeysHelper.signedKEMPreKey(lastResortKey.keyId(), AuthHelper.VALID_IDENTITY_KEY_PAIR),
Optional.of(lastResortKey),
409)
);
}
private static byte[] getKeyDigest(final IdentityKey identityKey, final ECSignedPreKey ecSignedPreKey, final KEMSignedPreKey lastResortKey)
throws NoSuchAlgorithmException {
final MessageDigest messageDigest = MessageDigest.getInstance("SHA-256");
messageDigest.update(identityKey.serialize());
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
return messageDigest.digest();
}
@Test
void checkKeysIncorrectDigestLength() {
try (final Response response =
resources.getJerseyTest()
.target("/v2/keys/check")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.entity(new CheckKeysRequest(IdentityType.ACI, new byte[31]), MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
}
try (final Response response =
resources.getJerseyTest()
.target("/v2/keys/check")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.entity(new CheckKeysRequest(IdentityType.ACI, new byte[33]), MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(422, response.getStatus());
}
}
}