Disallow identity key changes

This commit is contained in:
Jon Chambers 2023-11-30 10:44:26 -05:00 committed by Jon Chambers
parent 85383fe581
commit ede9297139
3 changed files with 56 additions and 136 deletions

View File

@ -4,12 +4,8 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.headers.Header;
@ -17,19 +13,15 @@ import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.PUT;
@ -41,8 +33,6 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
@ -54,12 +44,13 @@ import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem;
import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator;
import org.whispersystems.textsecuregcm.entities.SetKeysRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -76,14 +67,6 @@ public class KeysController {
private final AccountsManager accounts;
private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys");
private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange");
private static final String IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME = name(KeysController.class, "identityKeyChangeForbidden");
private static final String IDENTITY_TYPE_TAG_NAME = "identityType";
private static final String HAS_IDENTITY_KEY_TAG_NAME = "hasIdentityKey";
private static final Logger logger = LoggerFactory.getLogger(KeysController.class);
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) {
this.rateLimiters = rateLimiters;
this.keys = keys;
@ -112,80 +95,65 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Upload new prekeys",
description = """
Upload new prekeys for this device. Can also be used, from the primary device only, to set the account's identity
key, but this is deprecated now that accounts can be created atomically.
""")
@Operation(summary = "Upload new prekeys", description = "Upload new pre-keys for this device.")
@ApiResponse(responseCode = "200", description = "Indicates that new keys were successfully stored.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public CompletableFuture<Response> setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@RequestBody @NotNull @Valid final SetKeysRequest preKeys,
@RequestBody @NotNull @Valid final SetKeysRequest setKeysRequest,
@Parameter(allowEmptyValue=true)
@Schema(
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
boolean updateAccount = false;
final Device device = disabledPermittedAuth.getAuthenticatedDevice();
if (preKeys.signedPreKey() != null && !preKeys.signedPreKey().equals(device.getSignedPreKey(identityType))) {
updateAccount = true;
}
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType));
final IdentityKey oldIdentityKey = account.getIdentityKey(identityType);
if (!Objects.equals(preKeys.identityKey(), oldIdentityKey)) {
updateAccount = true;
final boolean hasIdentityKey = oldIdentityKey != null;
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))
.and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey))
.and(IDENTITY_TYPE_TAG_NAME, identityType.name());
if (!device.isPrimary()) {
Metrics.counter(IDENTITY_KEY_CHANGE_FORBIDDEN_COUNTER_NAME, tags).increment();
throw new ForbiddenException();
}
Metrics.counter(IDENTITY_KEY_CHANGE_COUNTER_NAME, tags).increment();
if (hasIdentityKey) {
logger.warn("Existing {} identity key changed; account age is {} days",
identityType,
Duration.between(Instant.ofEpochMilli(device.getCreated()), Instant.now()).toDays());
}
}
if (updateAccount) {
account = accounts.update(account, a -> {
if (preKeys.signedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(preKeys.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(preKeys.signedPreKey());
}
});
}
if (setKeysRequest.signedPreKey() != null &&
!setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) {
account = accounts.update(account, a -> a.getDevice(device.getId()).ifPresent(d -> {
switch (identityType) {
case ACI -> a.setIdentityKey(preKeys.identityKey());
case PNI -> a.setPhoneNumberIdentityKey(preKeys.identityKey());
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
});
}));
}
return keys.store(account.getIdentifier(identityType), device.getId(),
preKeys.preKeys(), preKeys.pqPreKeys(), preKeys.signedPreKey(), preKeys.pqLastResortPreKey())
setKeysRequest.preKeys(), setKeysRequest.pqPreKeys(), setKeysRequest.signedPreKey(), setKeysRequest.pqLastResortPreKey())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
private void checkSignedPreKeySignatures(final SetKeysRequest setKeysRequest, final IdentityKey identityKey) {
final List<SignedPreKey<?>> signedPreKeys = new ArrayList<>();
if (setKeysRequest.pqPreKeys() != null) {
signedPreKeys.addAll(setKeysRequest.pqPreKeys());
}
if (setKeysRequest.pqLastResortPreKey() != null) {
signedPreKeys.add(setKeysRequest.pqLastResortPreKey());
}
if (setKeysRequest.signedPreKey() != null) {
signedPreKeys.add(setKeysRequest.signedPreKey());
}
final boolean allSignaturesValid =
signedPreKeys.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, signedPreKeys);
if (!allSignaturesValid) {
throw new WebApplicationException("Invalid signature", 422);
}
}
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
@ -288,7 +256,7 @@ public class KeysController {
@ChangesDeviceEnabledState
@Operation(summary = "Upload a new signed prekey",
description = """
Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys with instead.
Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys instead.
""")
@ApiResponse(responseCode = "200", description = "Indicates that new prekey was successfully stored.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")

View File

@ -4,16 +4,9 @@
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull;
import java.util.ArrayList;
import java.util.List;
import javax.validation.Valid;
public record SetKeysRequest(
@Valid
@ -46,31 +39,5 @@ public record SetKeysRequest(
signed post-quantum last-resort prekey for the device; if absent, a stored last-resort prekey will *not* be
deleted. If present, must have a valid signature from the identity key in this request.
""")
KEMSignedPreKey pqLastResortPreKey,
@JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@NotNull
@Schema(description = """
Required. The public identity key for this identity (account or phone-number identity). If this device is not
the primary device for the account, must match the existing stored identity key for this identity.
""")
IdentityKey identityKey
) {
@AssertTrue
public boolean isSignatureValidOnEachSignedKey() {
List<SignedPreKey<?>> spks = new ArrayList<>();
if (pqPreKeys != null) {
spks.addAll(pqPreKeys);
}
if (pqLastResortPreKey != null) {
spks.add(pqLastResortPreKey);
}
if (signedPreKey != null) {
spks.add(signedPreKey);
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks);
}
KEMSignedPreKey pqLastResortPreKey) {
}

View File

@ -734,7 +734,9 @@ class KeysControllerTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey);
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey);
Response response =
resources.getJerseyTest()
@ -751,7 +753,6 @@ class KeysControllerTest {
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@ -766,7 +767,9 @@ class KeysControllerTest {
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SetKeysRequest setKeysRequest =
new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey, identityKey);
new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey);
Response response =
resources.getJerseyTest()
@ -785,7 +788,6 @@ class KeysControllerTest {
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@ -869,7 +871,9 @@ class KeysControllerTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey);
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.PNI)).thenReturn(identityKey);
Response response =
resources.getJerseyTest()
@ -887,7 +891,6 @@ class KeysControllerTest {
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@ -902,7 +905,9 @@ class KeysControllerTest {
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SetKeysRequest setKeysRequest =
new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey, identityKey);
new SetKeysRequest(List.of(preKey), signedPreKey, List.of(pqPreKey), pqLastResortPreKey);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.PNI)).thenReturn(identityKey);
Response response =
resources.getJerseyTest()
@ -922,7 +927,6 @@ class KeysControllerTest {
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@ -930,7 +934,7 @@ class KeysControllerTest {
@Test
void putPrekeyWithInvalidSignature() {
final ECSignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair());
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(), badSignedPreKey, null, null, IDENTITY_KEY);
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(), badSignedPreKey, null, null);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
@ -949,7 +953,9 @@ class KeysControllerTest {
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null, identityKey);
when(AuthHelper.DISABLED_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(identityKey);
final SetKeysRequest setKeysRequest = new SetKeysRequest(List.of(preKey), signedPreKey, null, null);
Response response =
resources.getJerseyTest()
@ -969,28 +975,7 @@ class KeysControllerTest {
assertThat(capturedList.get(0).keyId()).isEqualTo(31337);
assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey());
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any());
}
@Test
void putIdentityKeyNonPrimary() {
final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR);
final List<ECPreKey> preKeys = List.of(preKey);
final SetKeysRequest setKeysRequest = new SetKeysRequest(preKeys, signedPreKey, null, null, IDENTITY_KEY);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, SAMPLE_DEVICE_ID2,
AuthHelper.VALID_PASSWORD_3_LINKED))
.put(Entity.entity(setKeysRequest, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
}
}