From ede929713920114fd3cf757ecdd7f72506c24df7 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 30 Nov 2023 10:44:26 -0500 Subject: [PATCH] Disallow identity key changes --- .../controllers/KeysController.java | 108 ++++++------------ .../entities/SetKeysRequest.java | 37 +----- .../controllers/KeysControllerTest.java | 47 +++----- 3 files changed, 56 insertions(+), 136 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 47f194370..ca312ff90 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -4,12 +4,8 @@ */ package org.whispersystems.textsecuregcm.controllers; -import static com.codahale.metrics.MetricRegistry.name; - import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tags; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.headers.Header; @@ -17,19 +13,15 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; -import java.time.Duration; -import java.time.Instant; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.Consumes; import javax.ws.rs.DefaultValue; -import javax.ws.rs.ForbiddenException; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; import javax.ws.rs.PUT; @@ -41,8 +33,6 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.signal.libsignal.protocol.IdentityKey; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; @@ -54,12 +44,13 @@ import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; +import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.entities.SetKeysRequest; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.experiment.Experiment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -76,14 +67,6 @@ public class KeysController { private final AccountsManager accounts; private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); - private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange"); - private static final String IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME = name(KeysController.class, "identityKeyChangeForbidden"); - - private static final String IDENTITY_TYPE_TAG_NAME = "identityType"; - private static final String HAS_IDENTITY_KEY_TAG_NAME = "hasIdentityKey"; - - private static final Logger logger = LoggerFactory.getLogger(KeysController.class); - public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) { this.rateLimiters = rateLimiters; this.keys = keys; @@ -112,80 +95,65 @@ public class KeysController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) @ChangesDeviceEnabledState - @Operation(summary = "Upload new prekeys", - description = """ - Upload new prekeys for this device. Can also be used, from the primary device only, to set the account's identity - key, but this is deprecated now that accounts can be created atomically. - """) + @Operation(summary = "Upload new prekeys", description = "Upload new pre-keys for this device.") @ApiResponse(responseCode = "200", description = "Indicates that new keys were successfully stored.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.") @ApiResponse(responseCode = "422", description = "Invalid request format.") public CompletableFuture setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth, - @RequestBody @NotNull @Valid final SetKeysRequest preKeys, + @RequestBody @NotNull @Valid final SetKeysRequest setKeysRequest, @Parameter(allowEmptyValue=true) @Schema( allowableValues={"aci", "pni"}, defaultValue="aci", description="whether this operation applies to the account (aci) or phone-number (pni) identity") - @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType, + @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { Account account = disabledPermittedAuth.getAccount(); - Device device = disabledPermittedAuth.getAuthenticatedDevice(); - boolean updateAccount = false; + final Device device = disabledPermittedAuth.getAuthenticatedDevice(); - if (preKeys.signedPreKey() != null && !preKeys.signedPreKey().equals(device.getSignedPreKey(identityType))) { - updateAccount = true; - } + checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType)); - final IdentityKey oldIdentityKey = account.getIdentityKey(identityType); - if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) { - updateAccount = true; - - final boolean hasIdentityKey = oldIdentityKey != null; - final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)) - .and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey)) - .and(IDENTITY_TYPE_TAG_NAME, identityType.name()); - - if (!device.isPrimary()) { - Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment(); - - throw new ForbiddenException(); - } - Metrics.counter(IDENTITY_KEY_CHANGE_COUNTER_NAME, tags).increment(); - - if (hasIdentityKey) { - logger.warn("Existing {} identity key changed; account age is {} days", - identityType, - Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays()); - } - } - - if (updateAccount) { - account = accounts.update(account, a -> { - if (preKeys.signedPreKey() != null) { - a.getDevice(device.getId()).ifPresent(d -> { - switch (identityType) { - case ACI -> d.setSignedPreKey(preKeys.signedPreKey()); - case PNI -> d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey()); - } - }); - } + if (setKeysRequest.signedPreKey() != null && + !setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) { + account = accounts.update(account, a -> a.getDevice(device.getId()).ifPresent(d -> { switch (identityType) { - case ACI -> a.setIdentityKey(preKeys.identityKey()); - case PNI -> a.setPhoneNumberIdentityKey(preKeys.identityKey()); + case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey()); + case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey()); } - }); + })); } return keys.store(account.getIdentifier(identityType), device.getId(), - preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey()) + setKeysRequest.preKeys(), setKeysRequest.pqPreKeys(), setKeysRequest.signedPreKey(), setKeysRequest.pqLastResortPreKey()) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } + private void checkSignedPreKeySignatures(final SetKeysRequest setKeysRequest, final IdentityKey identityKey) { + final List> signedPreKeys = new ArrayList<>(); + + if (setKeysRequest.pqPreKeys() != null) { + signedPreKeys.addAll(setKeysRequest.pqPreKeys()); + } + + if (setKeysRequest.pqLastResortPreKey() != null) { + signedPreKeys.add(setKeysRequest.pqLastResortPreKey()); + } + + if (setKeysRequest.signedPreKey() != null) { + signedPreKeys.add(setKeysRequest.signedPreKey()); + } + + final boolean allSignaturesValid = + signedPreKeys.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, signedPreKeys); + + if (!allSignaturesValid) { + throw new WebApplicationException("Invalid signature", 422); + } + } + @GET @Path("/{identifier}/{device_id}") @Produces(MediaType.APPLICATION_JSON) @@ -288,7 +256,7 @@ public class KeysController { @ChangesDeviceEnabledState @Operation(summary = "Upload a new signed prekey", description = """ - Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys with instead. + Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys instead. """) @ApiResponse(responseCode = "200", description = "Indicates that new prekey was successfully stored.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SetKeysRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SetKeysRequest.java index 272b63db1..e8a797c37 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SetKeysRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SetKeysRequest.java @@ -4,16 +4,9 @@ */ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; -import org.signal.libsignal.protocol.IdentityKey; -import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; -import javax.validation.Valid; -import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.NotNull; -import java.util.ArrayList; import java.util.List; +import javax.validation.Valid; public record SetKeysRequest( @Valid @@ -46,31 +39,5 @@ public record SetKeysRequest( signed post-quantum last-resort prekey for the device; if absent, a stored last-resort prekey will *not* be deleted. If present, must have a valid signature from the identity key in this request. """) - KEMSignedPreKey pqLastResortPreKey, - - @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) - @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) - @NotNull - @Schema(description = """ - Required. The public identity key for this identity (account or phone-number identity). If this device is not - the primary device for the account, must match the existing stored identity key for this identity. - """) - IdentityKey identityKey -) { - - @AssertTrue - public boolean isSignatureValidOnEachSignedKey() { - List> spks = new ArrayList<>(); - if (pqPreKeys != null) { - spks.addAll(pqPreKeys); - } - if (pqLastResortPreKey != null) { - spks.add(pqLastResortPreKey); - } - if (signedPreKey != null) { - spks.add(signedPreKey); - } - return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks); - } - + KEMSignedPreKey pqLastResortPreKey) { } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 699ab3a34..2cf0c4e58 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -734,7 +734,9 @@ class KeysControllerTest { final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey); + final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null); + + when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey); Response response = resources.getJerseyTest() @@ -751,7 +753,6 @@ class KeysControllerTest { assertThat(listCaptor.getValue()).containsExactly(preKey); - verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); } @@ -766,7 +767,9 @@ class KeysControllerTest { final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final SetKeysRequest setKeysRequest = - new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey, identityKey); + new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey); + + when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey); Response response = resources.getJerseyTest() @@ -785,7 +788,6 @@ class KeysControllerTest { assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); - verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); } @@ -869,7 +871,9 @@ class KeysControllerTest { final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey); + final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null); + + when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.PNI)).thenReturn(identityKey); Response response = resources.getJerseyTest() @@ -887,7 +891,6 @@ class KeysControllerTest { assertThat(listCaptor.getValue()).containsExactly(preKey); - verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); } @@ -902,7 +905,9 @@ class KeysControllerTest { final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final SetKeysRequest setKeysRequest = - new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey, identityKey); + new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey); + + when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.PNI)).thenReturn(identityKey); Response response = resources.getJerseyTest() @@ -922,7 +927,6 @@ class KeysControllerTest { assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); - verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); } @@ -930,7 +934,7 @@ class KeysControllerTest { @Test void putPrekeyWithInvalidSignature() { final ECSignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair()); - final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(), badSignedPreKey, null, null, IDENTITY_KEY); + final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(), badSignedPreKey, null, null); Response response = resources.getJerseyTest() .target("/v2/keys") @@ -949,7 +953,9 @@ class KeysControllerTest { final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey); + when(AuthHelper.DISABLED_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey); + + final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null); Response response = resources.getJerseyTest() @@ -969,28 +975,7 @@ class KeysControllerTest { assertThat(capturedList.get(0).keyId()).isEqualTo(31337); assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey()); - verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any()); } - - @Test - void putIdentityKeyNonPrimary() { - final ECPreKey preKey = KeysHelper.ecPreKey(31337); - final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR); - - final List preKeys = List.of(preKey); - - final SetKeysRequest setKeysRequest = new SetKeysRequest(preKeys, signedPreKey, null, null, IDENTITY_KEY); - - Response response = - resources.getJerseyTest() - .target("/v2/keys") - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, SAMPLE_DEVICE_ID2, - AuthHelper.VALID_PASSWORD_3_LINKED)) - .put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE)); - - assertThat(response.getStatus()).isEqualTo(403); - } }