diff --git a/integration-tests/src/main/java/org/signal/integration/Operations.java b/integration-tests/src/main/java/org/signal/integration/Operations.java index 140eac416..3ed873e2d 100644 --- a/integration-tests/src/main/java/org/signal/integration/Operations.java +++ b/integration-tests/src/main/java/org/signal/integration/Operations.java @@ -32,15 +32,18 @@ import org.signal.integration.config.Config; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.kem.KEMKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyType; +import org.signal.libsignal.protocol.kem.KEMPublicKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.RegistrationRequest; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.HeaderUtils; @@ -325,15 +328,15 @@ public final class Operations { } } - private static SignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) { - final byte[] pubKey = Curve.generateKeyPair().getPublicKey().serialize(); - final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); - return new SignedPreKey(id, pubKey, sig); + private static ECSignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) { + final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey(); + final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize()); + return new ECSignedPreKey(id, pubKey, sig); } - private static SignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) { - final byte[] pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey().serialize(); - final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); - return new SignedPreKey(id, pubKey, sig); + private static KEMSignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) { + final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey(); + final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize()); + return new KEMSignedPreKey(id, pubKey, sig); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 6de5d9993..ecf1d79c1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -41,7 +41,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; -import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; @@ -207,9 +209,9 @@ public class KeysController { for (Device device : devices) { UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; - SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); - PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); - SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; + ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); + ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); + KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { final int registrationId = usePhoneNumberIdentity ? @@ -234,7 +236,7 @@ public class KeysController { @Consumes(MediaType.APPLICATION_JSON) @ChangesDeviceEnabledState public void setSignedKey(@Auth final AuthenticatedAccount auth, - @Valid final SignedPreKey signedPreKey, + @Valid final ECSignedPreKey signedPreKey, @QueryParam("identity") final Optional identityType) { Device device = auth.getAuthenticatedDevice(); @@ -252,11 +254,11 @@ public class KeysController { @GET @Path("/signed") @Produces(MediaType.APPLICATION_JSON) - public Optional getSignedKey(@Auth final AuthenticatedAccount auth, + public Optional getSignedKey(@Auth final AuthenticatedAccount auth, @QueryParam("identity") final Optional identityType) { Device device = auth.getAuthenticatedDevice(); - SignedPreKey signedPreKey = usePhoneNumberIdentity(identityType) ? + ECSignedPreKey signedPreKey = usePhoneNumberIdentity(identityType) ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); return Optional.ofNullable(signedPreKey); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index 27d75e22e..74a394956 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -20,8 +20,6 @@ import javax.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record ChangeNumberRequest( @Schema(description=""" @@ -54,7 +52,7 @@ public record ChangeNumberRequest( @Schema(description=""" A new signed elliptic-curve prekey for each enabled device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @NotNull @Valid Map devicePniSignedPrekeys, + @NotNull @Valid Map devicePniSignedPrekeys, @Schema(description=""" A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. @@ -62,14 +60,14 @@ public record ChangeNumberRequest( If present, must contain one prekey per enabled device including this one. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + @Valid Map devicePniPqLastResortPrekeys, @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - List spks = new ArrayList<>(); + List> spks = new ArrayList<>(); if (devicePniSignedPrekeys != null) { spks.addAll(devicePniSignedPrekeys.values()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index 6a922728d..88df907ee 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -19,8 +19,6 @@ import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record ChangePhoneNumberRequest( @Schema(description="the new phone number for this account") @@ -46,7 +44,7 @@ public record ChangePhoneNumberRequest( @Schema(description=""" A new signed elliptic-curve prekey for each enabled device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Nullable Map devicePniSignedPrekeys, + @Nullable Map devicePniSignedPrekeys, @Schema(description=""" A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. @@ -54,14 +52,14 @@ public record ChangePhoneNumberRequest( If present, must contain one prekey per enabled device including this one. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Nullable @Valid Map devicePniPqLastResortPrekeys, + @Nullable @Valid Map devicePniPqLastResortPrekeys, @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") @Nullable Map pniRegistrationIds) { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - List spks = new ArrayList<>(); + List> spks = new ArrayList<>(); if (devicePniSignedPrekeys != null) { spks.addAll(devicePniSignedPrekeys.values()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceActivationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceActivationRequest.java index feddc9354..57be1f6b7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceActivationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceActivationRequest.java @@ -4,9 +4,6 @@ import io.swagger.v3.oas.annotations.media.Schema; import javax.validation.Valid; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; - import java.util.Optional; public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @@ -14,28 +11,28 @@ public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, + Optional<@Valid ECSignedPreKey> aciSignedPreKey, @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ A signed EC pre-key to be associated with this account's PNI. If provided, an account will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, + Optional<@Valid ECSignedPreKey> pniSignedPreKey, @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ A signed Kyber-1024 "last resort" pre-key to be associated with this account's ACI. If provided, an account will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciPqLastResortPreKey, + Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ A signed Kyber-1024 "last resort" pre-key to be associated with this account's PNI. If provided, an account will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniPqLastResortPreKey, + Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey, @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ An APNs token set for the account's primary device. If provided, the account's primary diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java new file mode 100644 index 000000000..7f244945c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECPreKey.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter; + +public record ECPreKey(long keyId, + @JsonSerialize(using = ECPublicKeyAdapter.Serializer.class) + @JsonDeserialize(using = ECPublicKeyAdapter.Deserializer.class) + ECPublicKey publicKey) implements PreKey { + + @Override + public byte[] serializedPublicKey() { + return publicKey().serialize(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java new file mode 100644 index 000000000..b8dac36bb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ECSignedPreKey.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter; +import java.util.Arrays; +import java.util.Objects; + +public record ECSignedPreKey(long keyId, + + @JsonSerialize(using = ECPublicKeyAdapter.Serializer.class) + @JsonDeserialize(using = ECPublicKeyAdapter.Deserializer.class) + ECPublicKey publicKey, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] signature) implements SignedPreKey { + + @Override + public byte[] serializedPublicKey() { + return publicKey().serialize(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ECSignedPreKey that = (ECSignedPreKey) o; + return keyId == that.keyId && publicKey.equals(that.publicKey) && Arrays.equals(signature, that.signature); + } + + @Override + public int hashCode() { + int result = Objects.hash(keyId, publicKey); + result = 31 * result + Arrays.hashCode(signature); + return result; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java new file mode 100644 index 000000000..88e83c978 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/KEMSignedPreKey.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.signal.libsignal.protocol.kem.KEMPublicKey; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.whispersystems.textsecuregcm.util.KEMPublicKeyAdapter; +import java.util.Arrays; +import java.util.Objects; + +public record KEMSignedPreKey(long keyId, + + @JsonSerialize(using = KEMPublicKeyAdapter.Serializer.class) + @JsonDeserialize(using = KEMPublicKeyAdapter.Deserializer.class) + KEMPublicKey publicKey, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] signature) implements SignedPreKey { + + @Override + public byte[] serializedPublicKey() { + return publicKey().serialize(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + KEMSignedPreKey that = (KEMSignedPreKey) o; + return keyId == that.keyId && publicKey.equals(that.publicKey) && Arrays.equals(signature, that.signature); + } + + @Override + public int hashCode() { + int result = Objects.hash(keyId, publicKey); + result = 31 * result + Arrays.hashCode(signature); + return result; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java index 6dab8a238..b12fb36a5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/LinkDeviceRequest.java @@ -25,10 +25,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, @JsonProperty("accountAttributes") AccountAttributes accountAttributes, - @JsonProperty("aciSignedPreKey") Optional<@Valid SignedPreKey> aciSignedPreKey, - @JsonProperty("pniSignedPreKey") Optional<@Valid SignedPreKey> pniSignedPreKey, - @JsonProperty("aciPqLastResortPreKey") Optional<@Valid SignedPreKey> aciPqLastResortPreKey, - @JsonProperty("pniPqLastResortPreKey") Optional<@Valid SignedPreKey> pniPqLastResortPreKey, + @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, + @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, + @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, + @JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index d02ff8d2b..e564f93b1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -15,8 +15,6 @@ import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record PhoneNumberIdentityKeyDistributionRequest( @NotNull @@ -37,7 +35,7 @@ public record PhoneNumberIdentityKeyDistributionRequest( @Schema(description=""" A new signed elliptic-curve prekey for each enabled device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - Map devicePniSignedPrekeys, + Map devicePniSignedPrekeys, @Schema(description=""" A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. @@ -45,7 +43,7 @@ public record PhoneNumberIdentityKeyDistributionRequest( If present, must contain one prekey per enabled device including this one. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + @Valid Map devicePniPqLastResortPrekeys, @NotNull @Valid @@ -54,7 +52,7 @@ public record PhoneNumberIdentityKeyDistributionRequest( @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - List spks = new ArrayList<>(devicePniSignedPrekeys.values()); + List> spks = new ArrayList<>(devicePniSignedPrekeys.values()); if (devicePniPqLastResortPrekeys != null) { spks.addAll(devicePniPqLastResortPrekeys.values()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java index 40d5e50a7..0ebe89cd4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java @@ -1,68 +1,15 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2023 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +public interface PreKey { -import javax.validation.constraints.NotEmpty; -import javax.validation.constraints.NotNull; -import java.util.Arrays; -import java.util.Objects; + long keyId(); -public class PreKey { + K publicKey(); - @JsonProperty - @NotNull - private long keyId; - - @JsonProperty - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - @NotEmpty - private byte[] publicKey; - - public PreKey() {} - - public PreKey(long keyId, byte[] publicKey) - { - this.keyId = keyId; - this.publicKey = publicKey; - } - - public byte[] getPublicKey() { - return publicKey; - } - - public void setPublicKey(byte[] publicKey) { - this.publicKey = publicKey; - } - - public long getKeyId() { - return keyId; - } - - public void setKeyId(long keyId) { - this.keyId = keyId; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - PreKey preKey = (PreKey) o; - return keyId == preKey.keyId && Arrays.equals(publicKey, preKey.publicKey); - } - - @Override - public int hashCode() { - int result = Objects.hash(keyId); - result = 31 * result + Arrays.hashCode(publicKey); - return result; - } + byte[] serializedPublicKey(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java index 42491b09f..0cf519c9f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java @@ -20,20 +20,20 @@ public class PreKeyResponseItem { @JsonProperty @Schema(description="the signed elliptic-curve prekey for the device, if one has been set") - private SignedPreKey signedPreKey; + private ECSignedPreKey signedPreKey; @JsonProperty @Schema(description="an unsigned elliptic-curve prekey for the device, if any remain") - private PreKey preKey; + private ECPreKey preKey; @JsonProperty @Schema(description="a signed post-quantum prekey for the device " + "(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)") - private SignedPreKey pqPreKey; + private KEMSignedPreKey pqPreKey; public PreKeyResponseItem() {} - public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) { + public PreKeyResponseItem(long deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey, KEMSignedPreKey pqPreKey) { this.deviceId = deviceId; this.registrationId = registrationId; this.signedPreKey = signedPreKey; @@ -42,17 +42,17 @@ public class PreKeyResponseItem { } @VisibleForTesting - public SignedPreKey getSignedPreKey() { + public ECSignedPreKey getSignedPreKey() { return signedPreKey; } @VisibleForTesting - public PreKey getPreKey() { + public ECPreKey getPreKey() { return preKey; } @VisibleForTesting - public SignedPreKey getPqPreKey() { + public KEMSignedPreKey getPqPreKey() { return pqPreKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java index 0eaebe835..67a712426 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java @@ -15,19 +15,13 @@ public abstract class PreKeySignatureValidator { public static final Counter INVALID_SIGNATURE_COUNTER = Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")); - public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection spks) { - try { - final boolean success = spks.stream() - .allMatch(spk -> identityKey.getPublicKey().verifySignature(spk.getPublicKey(), spk.getSignature())); + public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection> spks) { + final boolean success = spks.stream().allMatch(spk -> spk.signatureValid(identityKey)); - if (!success) { - INVALID_SIGNATURE_COUNTER.increment(); - } - - return success; - } catch (final IllegalArgumentException e) { + if (!success) { INVALID_SIGNATURE_COUNTER.increment(); - return false; } + + return success; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java index d1479bd06..8b17d4da7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java @@ -16,8 +16,6 @@ import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public class PreKeyState { @@ -26,16 +24,15 @@ public class PreKeyState { @Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " + "If present and not empty, replaces all stored unsigned EC prekeys for the device; " + "if absent or empty, any stored unsigned EC prekeys for the device are not deleted.") - private List<@ValidPreKey(type=PreKeyType.ECC) PreKey> preKeys; + private List<@Valid ECPreKey> preKeys; @JsonProperty @Valid - @ValidPreKey(type=PreKeyType.ECC) @Schema(description="An optional signed elliptic-curve prekey to use for this device. " + "If present, replaces the stored signed elliptic-curve prekey for the device; " + "if absent, the stored signed prekey is not deleted. " + "If present, must have a valid signature from the identity key in this request.") - private SignedPreKey signedPreKey; + private ECSignedPreKey signedPreKey; @JsonProperty @Valid @@ -43,16 +40,15 @@ public class PreKeyState { "Each key must have a valid signature from the identity key in this request. " + "If present and not empty, replaces all stored unsigned PQ prekeys for the device; " + "if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.") - private List<@ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> pqPreKeys; + private List<@Valid KEMSignedPreKey> pqPreKeys; @JsonProperty @Valid - @ValidPreKey(type=PreKeyType.KYBER) @Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " + "If present, replaces the stored signed post-quantum last-resort prekey for the device; " + "if absent, a stored last-resort prekey will *not* be deleted. " + "If present, must have a valid signature from the identity key in this request.") - private SignedPreKey pqLastResortPreKey; + private KEMSignedPreKey pqLastResortPreKey; @JsonProperty @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) @@ -67,12 +63,12 @@ public class PreKeyState { public PreKeyState() {} @VisibleForTesting - public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List keys) { + public PreKeyState(IdentityKey identityKey, ECSignedPreKey signedPreKey, List keys) { this(identityKey, signedPreKey, keys, null, null); } @VisibleForTesting - public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List keys, List pqKeys, SignedPreKey pqLastResortKey) { + public PreKeyState(IdentityKey identityKey, ECSignedPreKey signedPreKey, List keys, List pqKeys, KEMSignedPreKey pqLastResortKey) { this.identityKey = identityKey; this.signedPreKey = signedPreKey; this.preKeys = keys; @@ -80,19 +76,19 @@ public class PreKeyState { this.pqLastResortPreKey = pqLastResortKey; } - public List getPreKeys() { + public List getPreKeys() { return preKeys; } - public SignedPreKey getSignedPreKey() { + public ECSignedPreKey getSignedPreKey() { return signedPreKey; } - public List getPqPreKeys() { + public List getPqPreKeys() { return pqPreKeys; } - public SignedPreKey getPqLastResortPreKey() { + public KEMSignedPreKey getPqLastResortPreKey() { return pqLastResortPreKey; } @@ -102,7 +98,7 @@ public class PreKeyState { @AssertTrue public boolean isSignatureValidOnEachSignedKey() { - List spks = new ArrayList<>(); + List> spks = new ArrayList<>(); if (pqPreKeys != null) { spks.addAll(pqPreKeys); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java index bd21f5367..19300d686 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -20,8 +20,6 @@ import javax.validation.constraints.NotNull; import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ The ID of an existing verification session as it appears in a verification session @@ -82,10 +80,10 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, @JsonProperty("aciIdentityKey") Optional aciIdentityKey, @JsonProperty("pniIdentityKey") Optional pniIdentityKey, - @JsonProperty("aciSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, - @JsonProperty("pniSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, - @JsonProperty("aciPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> aciPqLastResortPreKey, - @JsonProperty("pniPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> pniPqLastResortPreKey, + @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, + @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey, + @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey, + @JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { @@ -106,7 +104,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private static boolean validatePreKeySignature(final Optional maybeIdentityKey, - final Optional maybeSignedPreKey) { + final Optional> maybeSignedPreKey) { return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey .map(identityKey -> PreKeySignatureValidator.validatePreKeySignatures(identityKey, List.of(signedPreKey))) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java index 21b5c3d5d..b4075f440 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java @@ -1,50 +1,17 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2023 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.entities; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.signal.libsignal.protocol.IdentityKey; -import javax.validation.constraints.NotEmpty; -import java.util.Arrays; +public interface SignedPreKey extends PreKey { -public class SignedPreKey extends PreKey { + byte[] signature(); - @JsonProperty - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - @NotEmpty - private byte[] signature; - - public SignedPreKey() {} - - public SignedPreKey(long keyId, byte[] publicKey, byte[] signature) { - super(keyId, publicKey); - this.signature = signature; - } - - public byte[] getSignature() { - return signature; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - if (!super.equals(o)) return false; - SignedPreKey that = (SignedPreKey) o; - return Arrays.equals(signature, that.signature); - } - - @Override - public int hashCode() { - int result = super.hashCode(); - result = 31 * result + Arrays.hashCode(signature); - return result; + default boolean signatureValid(final IdentityKey identityKey) { + return identityKey.getPublicKey().verifySignature(serializedPublicKey(), signature()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 8e7f08a17..b51078b2b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -45,7 +45,8 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; @@ -258,8 +259,8 @@ public class AccountsManager { public Account changeNumber(final Account account, final String targetNumber, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { final String originalNumber = account.getNumber(); @@ -350,8 +351,8 @@ public class AccountsManager { public Account updatePniKeys(final Account account, final IdentityKey pniIdentityKey, - final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, + final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, final Map pniRegistrationIds) throws MismatchedDevicesException { validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); @@ -369,7 +370,7 @@ public class AccountsManager { private boolean setPniKeys(final Account account, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniSignedPreKeys, @Nullable final Map pniRegistrationIds) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { return false; @@ -383,7 +384,7 @@ public class AccountsManager { if (!device.isEnabled()) { continue; } - SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId()); + ECSignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId()); int registrationId = pniRegistrationIds.get(device.getId()); changed = changed || !signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) || @@ -398,8 +399,8 @@ public class AccountsManager { } private void validateDevices(final Account account, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, @Nullable final Map pniRegistrationIds) throws MismatchedDevicesException { if (pniSignedPreKeys == null && pniRegistrationIds == null) { return; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index e754a5fe3..719f9023a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -20,9 +20,10 @@ import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; @@ -41,8 +42,8 @@ public class ChangeNumberManager { public Account changeNumber(final Account account, final String number, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map deviceSignedPreKeys, - @Nullable final Map devicePqLastResortPreKeys, + @Nullable final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, @Nullable final List deviceMessages, @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException, StaleDevicesException { @@ -81,8 +82,8 @@ public class ChangeNumberManager { public Account updatePniKeys(final Account account, final IdentityKey pniIdentityKey, - final Map deviceSignedPreKeys, - @Nullable final Map devicePqLastResortPreKeys, + final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, final List deviceMessages, final Map pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { validateDeviceMessages(account, deviceMessages); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index e6d2197fe..56cdcfe3a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -13,7 +13,7 @@ import java.util.stream.Collectors; import java.util.stream.LongStream; import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.util.Util; public class Device { @@ -61,10 +61,10 @@ public class Device { private Integer phoneNumberIdentityRegistrationId; @JsonProperty - private SignedPreKey signedPreKey; + private ECSignedPreKey signedPreKey; @JsonProperty("pniSignedPreKey") - private SignedPreKey phoneNumberIdentitySignedPreKey; + private ECSignedPreKey phoneNumberIdentitySignedPreKey; @JsonProperty private long lastSeen; @@ -230,19 +230,19 @@ public class Device { this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; } - public SignedPreKey getSignedPreKey() { + public ECSignedPreKey getSignedPreKey() { return signedPreKey; } - public void setSignedPreKey(SignedPreKey signedPreKey) { + public void setSignedPreKey(ECSignedPreKey signedPreKey) { this.signedPreKey = signedPreKey; } - public SignedPreKey getPhoneNumberIdentitySignedPreKey() { + public ECSignedPreKey getPhoneNumberIdentitySignedPreKey() { return phoneNumberIdentitySignedPreKey; } - public void setPhoneNumberIdentitySignedPreKey(final SignedPreKey phoneNumberIdentitySignedPreKey) { + public void setPhoneNumberIdentitySignedPreKey(final ECSignedPreKey phoneNumberIdentitySignedPreKey) { this.phoneNumberIdentitySignedPreKey = phoneNumberIdentitySignedPreKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index eac1dc5f7..a78bedf2d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -13,15 +13,15 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import javax.annotation.Nullable; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; public class KeysManager { private final SingleUseECPreKeyStore ecPreKeys; private final SingleUseKEMPreKeyStore pqPreKeys; - private final RepeatedUseSignedPreKeyStore pqLastResortKeys; + private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys; public KeysManager( final DynamoDbAsyncClient dynamoDbAsyncClient, @@ -30,18 +30,18 @@ public class KeysManager { final String pqLastResortTableName) { this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName); this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName); - this.pqLastResortKeys = new RepeatedUseSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName); + this.pqLastResortKeys = new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName); } - public void store(final UUID identifier, final long deviceId, final List keys) { + public void store(final UUID identifier, final long deviceId, final List keys) { store(identifier, deviceId, keys, null, null); } public void store( final UUID identifier, final long deviceId, - @Nullable final List ecKeys, - @Nullable final List pqKeys, - @Nullable final SignedPreKey pqLastResortKey) { + @Nullable final List ecKeys, + @Nullable final List pqKeys, + @Nullable final KEMSignedPreKey pqLastResortKey) { final List> storeFutures = new ArrayList<>(); @@ -60,15 +60,15 @@ public class KeysManager { CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join(); } - public void storePqLastResort(final UUID identifier, final Map keys) { + public void storePqLastResort(final UUID identifier, final Map keys) { pqLastResortKeys.store(identifier, keys).join(); } - public Optional takeEC(final UUID identifier, final long deviceId) { + public Optional takeEC(final UUID identifier, final long deviceId) { return ecPreKeys.take(identifier, deviceId).join(); } - public Optional takePQ(final UUID identifier, final long deviceId) { + public Optional takePQ(final UUID identifier, final long deviceId) { return pqPreKeys.take(identifier, deviceId) .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) @@ -76,9 +76,8 @@ public class KeysManager { } @VisibleForTesting - Optional getLastResort(final UUID identifier, final long deviceId) { - return pqLastResortKeys.find(identifier, deviceId).join() - .map(signedPreKey -> signedPreKey); + Optional getLastResort(final UUID identifier, final long deviceId) { + return pqLastResortKeys.find(identifier, deviceId).join(); } public List getPqEnabledDevices(final UUID identifier) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java new file mode 100644 index 000000000..e6720a213 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.signal.libsignal.protocol.InvalidKeyException; +import org.signal.libsignal.protocol.kem.KEMPublicKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import java.util.Map; +import java.util.UUID; + +public class RepeatedUseKEMSignedPreKeyStore extends RepeatedUseSignedPreKeyStore { + + public RepeatedUseKEMSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { + super(dynamoDbAsyncClient, tableName); + } + + @Override + protected Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final KEMSignedPreKey signedPreKey) { + + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), + KEY_DEVICE_ID, getSortKey(deviceId), + ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.keyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()), + ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature())); + } + + @Override + protected KEMSignedPreKey getPreKeyFromItem(final Map item) { + try { + return new KEMSignedPreKey( + Long.parseLong(item.get(ATTR_KEY_ID).n()), + new KEMPublicKey(item.get(ATTR_PUBLIC_KEY).b().asByteArray()), + item.get(ATTR_SIGNATURE).b().asByteArray()); + } catch (final InvalidKeyException e) { + // This should never happen since we're serializing keys directly from `KEMPublicKey` instances on the way in + throw new IllegalArgumentException(e); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index eea257dea..06641568d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -5,15 +5,12 @@ package org.whispersystems.textsecuregcm.storage; -import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import org.signal.libsignal.protocol.InvalidKeyException; -import org.signal.libsignal.protocol.kem.KEMPublicKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -37,7 +34,7 @@ import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; * Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed * provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key. */ -public class RepeatedUseSignedPreKeyStore { +public abstract class RepeatedUseSignedPreKeyStore> { private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; @@ -63,9 +60,6 @@ public class RepeatedUseSignedPreKeyStore { private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey"); private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; - private static final Counter INVALID_KEY_COUNTER = - Metrics.counter(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "invalidKey")); - public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { this.dynamoDbAsyncClient = dynamoDbAsyncClient; this.tableName = tableName; @@ -81,7 +75,7 @@ public class RepeatedUseSignedPreKeyStore { * * @return a future that completes once the key has been stored */ - public CompletableFuture store(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { + public CompletableFuture store(final UUID identifier, final long deviceId, final K signedPreKey) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.putItem(PutItemRequest.builder() @@ -101,14 +95,14 @@ public class RepeatedUseSignedPreKeyStore { * * @return a future that completes once all keys have been stored */ - public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { + public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder() .transactItems(signedPreKeysByDeviceId.entrySet().stream() .map(entry -> { final long deviceId = entry.getKey(); - final SignedPreKey signedPreKey = entry.getValue(); + final K signedPreKey = entry.getValue(); return TransactWriteItem.builder() .put(Put.builder() @@ -131,10 +125,10 @@ public class RepeatedUseSignedPreKeyStore { * @return a future that yields an optional signed pre-key if one is available for the target device or empty if no * key could be found for the target device */ - public CompletableFuture> find(final UUID identifier, final long deviceId) { + public CompletableFuture> find(final UUID identifier, final long deviceId) { final Timer.Sample sample = Timer.start(); - final CompletableFuture> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder() + final CompletableFuture> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder() .tableName(tableName) .key(getPrimaryKey(identifier, deviceId)) .consistentRead(true) @@ -202,41 +196,21 @@ public class RepeatedUseSignedPreKeyStore { .map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n())); } - private static Map getPrimaryKey(final UUID identifier, final long deviceId) { + protected static Map getPrimaryKey(final UUID identifier, final long deviceId) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID, getSortKey(deviceId)); } - private static AttributeValue getPartitionKey(final UUID accountUuid) { + protected static AttributeValue getPartitionKey(final UUID accountUuid) { return AttributeValues.fromUUID(accountUuid); } - private static AttributeValue getSortKey(final long deviceId) { + protected static AttributeValue getSortKey(final long deviceId) { return AttributeValues.fromLong(deviceId); } - private static Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final SignedPreKey signedPreKey) { - return Map.of( - KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), - KEY_DEVICE_ID, getSortKey(deviceId), - ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.getKeyId()), - ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()), - ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature())); - } + protected abstract Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final K signedPreKey); - private static SignedPreKey getPreKeyFromItem(final Map item) { - final byte[] publicKeyBytes = item.get(ATTR_PUBLIC_KEY).b().asByteArray(); - - try { - new KEMPublicKey(publicKeyBytes); - } catch (final InvalidKeyException e) { - INVALID_KEY_COUNTER.increment(); - } - - return new SignedPreKey( - Long.parseLong(item.get(ATTR_KEY_ID).n()), - publicKeyBytes, - item.get(ATTR_SIGNATURE).b().asByteArray()); - } + protected abstract K getPreKeyFromItem(final Map item); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java index 6f7f91ae7..9296aabb9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -5,46 +5,39 @@ package org.whispersystems.textsecuregcm.storage; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import java.util.Map; import java.util.UUID; -public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { - - private static final Counter INVALID_KEY_COUNTER = - Metrics.counter(MetricsUtil.name(SingleUseECPreKeyStore.class, "invalidKey")); +public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { super(dynamoDbAsyncClient, tableName); } @Override - protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final PreKey preKey) { + protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final ECPreKey preKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), - KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), - ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey())); + KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey())); } @Override - protected PreKey getPreKeyFromItem(final Map item) { + protected ECPreKey getPreKeyFromItem(final Map item) { final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); try { - new ECPublicKey(publicKey); + return new ECPreKey(keyId, new ECPublicKey(publicKey)); } catch (final InvalidKeyException e) { - INVALID_KEY_COUNTER.increment(); + // This should never happen since we're serializing keys directly from `ECPublicKey` instances on the way in + throw new IllegalArgumentException(e); } - - return new PreKey(keyId, publicKey); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java index ba41f0629..2e54fad37 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java @@ -5,48 +5,41 @@ package org.whispersystems.textsecuregcm.storage; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.kem.KEMPublicKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import java.util.Map; import java.util.UUID; -public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore { - - private static final Counter INVALID_KEY_COUNTER = - Metrics.counter(MetricsUtil.name(SingleUseKEMPreKeyStore.class, "invalidKey")); +public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore { protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { super(dynamoDbAsyncClient, tableName); } @Override - protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { + protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final KEMSignedPreKey signedPreKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), - KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.getKeyId()), - ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()), - ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature())); + KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()), + ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature())); } @Override - protected SignedPreKey getPreKeyFromItem(final Map item) { + protected KEMSignedPreKey getPreKeyFromItem(final Map item) { final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); - final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); - final byte[] signature = extractByteArray(item.get(ATTR_SIGNATURE)); + final byte[] publicKey = item.get(ATTR_PUBLIC_KEY).b().asByteArray(); + final byte[] signature = item.get(ATTR_SIGNATURE).b().asByteArray(); try { - new KEMPublicKey(publicKey); + return new KEMSignedPreKey(keyId, new KEMPublicKey(publicKey), signature); } catch (final InvalidKeyException e) { - INVALID_KEY_COUNTER.increment(); + // This should never happen since we're serializing keys directly from `KEMPublicKey` instances on the way in + throw new IllegalArgumentException(e); } - - return new SignedPreKey(keyId, publicKey, signature); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java index 1a5864f38..2b87e5b41 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -47,7 +47,7 @@ import software.amazon.awssdk.services.dynamodb.model.Select; * the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party * may fall back to using the device's repeated-use ("last-resort") signed pre-key instead. */ -public abstract class SingleUsePreKeyStore { +public abstract class SingleUsePreKeyStore> { private final DynamoDbAsyncClient dynamoDbAsyncClient; private final String tableName; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ECPublicKeyAdapter.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ECPublicKeyAdapter.java new file mode 100644 index 000000000..cd25cdb65 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ECPublicKeyAdapter.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.util.Base64; +import org.signal.libsignal.protocol.InvalidKeyException; +import org.signal.libsignal.protocol.ecc.ECPublicKey; + +public class ECPublicKeyAdapter { + + public static class Serializer extends JsonSerializer { + + @Override + public void serialize(final ECPublicKey ecPublicKey, + final JsonGenerator jsonGenerator, + final SerializerProvider serializers) throws IOException { + + jsonGenerator.writeString(Base64.getEncoder().encodeToString(ecPublicKey.serialize())); + } + } + + public static class Deserializer extends JsonDeserializer { + + @Override + public ECPublicKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException { + final byte[] ecPublicKeyBytes; + + try { + ecPublicKeyBytes = Base64.getDecoder().decode(parser.getValueAsString()); + } catch (final IllegalArgumentException e) { + throw new JsonParseException(parser, "Could not parse EC public key as a base64-encoded value", e); + } + + try { + return new ECPublicKey(ecPublicKeyBytes); + } catch (final InvalidKeyException e) { + throw new JsonParseException(parser, "Could not interpret key bytes as an EC public key", e); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/KEMPublicKeyAdapter.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/KEMPublicKeyAdapter.java new file mode 100644 index 000000000..bbbf7d62c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/KEMPublicKeyAdapter.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import org.signal.libsignal.protocol.InvalidKeyException; +import org.signal.libsignal.protocol.kem.KEMPublicKey; +import java.io.IOException; +import java.util.Base64; + +public class KEMPublicKeyAdapter { + + public static class Serializer extends JsonSerializer { + + @Override + public void serialize(final KEMPublicKey kemPublicKey, + final JsonGenerator jsonGenerator, + final SerializerProvider serializers) throws IOException { + + jsonGenerator.writeString(Base64.getEncoder().encodeToString(kemPublicKey.serialize())); + } + } + + public static class Deserializer extends JsonDeserializer { + + @Override + public KEMPublicKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException { + final byte[] kemPublicKeyBytes; + + try { + kemPublicKeyBytes = Base64.getDecoder().decode(parser.getValueAsString()); + } catch (final IllegalArgumentException e) { + throw new JsonParseException(parser, "Could not parse KEM public key as a base64-encoded value", e); + } + + try { + return new KEMPublicKey(kemPublicKeyBytes); + } catch (final InvalidKeyException e) { + throw new JsonParseException(parser, "Could not interpret key bytes as a KEM public key", e); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKey.java deleted file mode 100644 index 9ce158ff4..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKey.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2023 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.util; - -import static java.lang.annotation.ElementType.FIELD; -import static java.lang.annotation.ElementType.PARAMETER; -import static java.lang.annotation.ElementType.TYPE_USE; -import static java.lang.annotation.RetentionPolicy.RUNTIME; -import java.lang.annotation.Documented; -import java.lang.annotation.Retention; -import java.lang.annotation.Target; -import javax.validation.Constraint; -import javax.validation.Payload; - -@Target({FIELD, PARAMETER, TYPE_USE}) -@Retention(RUNTIME) -@Constraint(validatedBy = {ValidPreKeyValidator.class}) -@Documented -public @interface ValidPreKey { - - public enum PreKeyType { - ECC, - KYBER - } - - PreKeyType type(); - - String message() default "{org.whispersystems.textsecuregcm.util.ValidPreKey.message}"; - - Class[] groups() default { }; - - Class[] payload() default { }; - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKeyValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKeyValidator.java deleted file mode 100644 index 83eef9a98..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/ValidPreKeyValidator.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.util; - -import javax.validation.ConstraintValidator; -import javax.validation.ConstraintValidatorContext; -import org.signal.libsignal.protocol.InvalidKeyException; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.kem.KEMPublicKey; -import org.whispersystems.textsecuregcm.entities.PreKey; - -public class ValidPreKeyValidator implements ConstraintValidator { - private ValidPreKey.PreKeyType type; - - @Override - public void initialize(ValidPreKey annotation) { - type = annotation.type(); - } - - @Override - public boolean isValid(PreKey value, ConstraintValidatorContext context) { - if (value == null) { - return true; - } - try { - switch (type) { - case ECC -> Curve.decodePoint(value.getPublicKey(), 0); - case KYBER -> new KEMPublicKey(value.getPublicKey()); - } - } catch (IllegalArgumentException | InvalidKeyException e) { - return false; - } - return true; - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 8bef5606d..d0018d1e8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -74,6 +74,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -84,7 +85,6 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiter; @@ -194,7 +194,7 @@ class MessageControllerTest { when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); } - private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { + private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) { final Device device = new Device(); device.setId(id); device.setRegistrationId(registrationId); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index 8703e69f4..a2c3c85c7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -55,10 +55,11 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockError; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; @@ -418,10 +419,10 @@ class RegistrationControllerTest { static Stream atomicAccountCreationConflictingChannel() { final Optional aciIdentityKey; final Optional pniIdentityKey; - final Optional aciSignedPreKey; - final Optional pniSignedPreKey; - final Optional aciPqLastResortPreKey; - final Optional pniPqLastResortPreKey; + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); @@ -507,10 +508,10 @@ class RegistrationControllerTest { static Stream atomicAccountCreationPartialSignedPreKeys() { final Optional aciIdentityKey; final Optional pniIdentityKey; - final Optional aciSignedPreKey; - final Optional pniSignedPreKey; - final Optional aciPqLastResortPreKey; - final Optional pniPqLastResortPreKey; + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); @@ -620,10 +621,10 @@ class RegistrationControllerTest { void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, final IdentityKey expectedAciIdentityKey, final IdentityKey expectedPniIdentityKey, - final SignedPreKey expectedAciSignedPreKey, - final SignedPreKey expectedPniSignedPreKey, - final SignedPreKey expectedAciPqLastResortPreKey, - final SignedPreKey expectedPniPqLastResortPreKey, + final ECSignedPreKey expectedAciSignedPreKey, + final ECSignedPreKey expectedPniSignedPreKey, + final KEMSignedPreKey expectedAciPqLastResortPreKey, + final KEMSignedPreKey expectedPniPqLastResortPreKey, final Optional expectedApnsToken, final Optional expectedApnsVoipToken, final Optional expectedGcmToken) throws InterruptedException { @@ -686,10 +687,10 @@ class RegistrationControllerTest { private static Stream atomicAccountCreationSuccess() { final Optional aciIdentityKey; final Optional pniIdentityKey; - final Optional aciSignedPreKey; - final Optional pniSignedPreKey; - final Optional aciPqLastResortPreKey; - final Optional pniPqLastResortPreKey; + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index bced40252..628877ee8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -29,7 +29,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -149,17 +149,17 @@ class AccountsManagerChangeNumberIntegrationTest { final String secondNumber = "+18005552222"; final int rotatedPniRegistrationId = 17; final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final SignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair); + final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities()); final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>()); - account.getMasterDevice().orElseThrow().setSignedPreKey(new SignedPreKey()); + account.getMasterDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); + final Map preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 7260c5e4e..82f1c8fbc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -55,7 +55,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; @@ -673,9 +674,10 @@ class AccountsManagerTest { final String number = "+14152222222"; Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); assertThrows(IllegalArgumentException.class, () -> accountsManager.changeNumber( - account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), + account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)), "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); verify(accounts, never()).update(any()); @@ -719,10 +721,10 @@ class AccountsManagerTest { final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( + final Map newSignedKeys = Map.of( 1L, KeysHelper.signedECPreKey(1, identityKeyPair), 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( + final Map newSignedPqKeys = Map.of( 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); final Map newRegistrationIds = Map.of(1L, 201, 2L, 202); @@ -768,14 +770,14 @@ class AccountsManagerTest { List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - Map newSignedKeys = Map.of( + Map newSignedKeys = Map.of( 1L, KeysHelper.signedECPreKey(1, identityKeyPair), 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); Map newRegistrationIds = Map.of(1L, 201, 2L, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); + Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -810,10 +812,10 @@ class AccountsManagerTest { List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( + final Map newSignedKeys = Map.of( 1L, KeysHelper.signedECPreKey(1, identityKeyPair), 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( + final Map newSignedPqKeys = Map.of( 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); Map newRegistrationIds = Map.of(1L, 201, 2L, 202); @@ -823,7 +825,7 @@ class AccountsManagerTest { when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); - Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); + Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 45b0d0094..1c8877943 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -1023,9 +1023,7 @@ class AccountsTest { assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId()); assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId()); assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen()); - assertThat(resultDevice.getSignedPreKey().getPublicKey()).isEqualTo(expectingDevice.getSignedPreKey().getPublicKey()); - assertThat(resultDevice.getSignedPreKey().getKeyId()).isEqualTo(expectingDevice.getSignedPreKey().getKeyId()); - assertThat(resultDevice.getSignedPreKey().getSignature()).isEqualTo(expectingDevice.getSignedPreKey().getSignature()); + assertThat(resultDevice.getSignedPreKey()).isEqualTo(expectingDevice.getSignedPreKey()); assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages()); assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent()); assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index fcad08cb0..b1d702f2f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -28,11 +28,15 @@ import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; public class ChangeNumberManagerTest { private AccountsManager accountsManager; @@ -106,8 +110,9 @@ public class ChangeNumberManagerTest { void changeNumberSetPrimaryDevicePrekey() throws Exception { Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); - var prekeys = Map.of(1L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); @@ -133,8 +138,9 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); @@ -176,9 +182,10 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); - final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); @@ -218,9 +225,10 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); - final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); @@ -258,8 +266,9 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); @@ -297,9 +306,10 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); - final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); @@ -344,7 +354,10 @@ public class ChangeNumberManagerTest { new IncomingMessage(1, 2, 1, "foo"), new IncomingMessage(1, 3, 1, "foo")); - final Map preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); + + final Map preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, @@ -374,7 +387,10 @@ public class ChangeNumberManagerTest { new IncomingMessage(1, 2, 1, "foo"), new IncomingMessage(1, 3, 1, "foo")); - final Map preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); + + final Map preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeviceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeviceTest.java index 237873593..44f25921a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeviceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DeviceTest.java @@ -13,14 +13,14 @@ import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; class DeviceTest { @ParameterizedTest @MethodSource void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId, - final SignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) { + final ECSignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) { final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis(); @@ -41,36 +41,36 @@ class DeviceTest { // master fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled Arguments.of(true, false, null, null, null, Duration.ofDays(60), false), Arguments.of(true, false, null, null, null, Duration.ofDays(1), false), - Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), - Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false), + Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false), + Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false), Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false), Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false), - Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), true), - Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true), + Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), true), + Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false), Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false), - Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), true), - Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true), + Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), true), + Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, true, null, null, null, Duration.ofDays(60), false), Arguments.of(true, true, null, null, null, Duration.ofDays(1), false), - Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), true), - Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true), + Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), true), + Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, false, null, null, null, Duration.ofDays(60), false), Arguments.of(false, false, null, null, null, Duration.ofDays(1), false), - Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), - Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false), + Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false), + Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false), Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false), Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false), - Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), false), - Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true), + Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), false), + Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false), Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false), - Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), false), - Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true), + Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), false), + Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, true, null, null, null, Duration.ofDays(60), false), Arguments.of(false, true, null, null, null, Duration.ofDays(1), false), - Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), - Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true) + Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false), + Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true) ); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index 83b19a07f..a114e5dc6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.security.SecureRandom; import java.util.List; import java.util.Map; import java.util.Optional; @@ -21,8 +20,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -37,6 +37,8 @@ class KeysManagerTest { private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final long DEVICE_ID = 1L; + private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + @BeforeEach void setup() { keysManager = new KeysManager( @@ -62,17 +64,17 @@ class KeysManagerTest { assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Repeatedly storing same key should have no effect"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ prekeys should have no effect on EC prekeys"); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001)); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestKEMSignedPreKey(1001)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); - assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); + assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId()); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), @@ -80,7 +82,7 @@ class KeysManagerTest { assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new EC prekeys should have no effect on PQ prekeys"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), @@ -88,13 +90,12 @@ class KeysManagerTest { keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(4), generateTestPreKey(5)), - List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)), - generateTestSignedPreKey(1002)); + List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(1002)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(), + assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(), "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); } @@ -102,10 +103,10 @@ class KeysManagerTest { void testTakeAccountAndDeviceId() { assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); - final PreKey preKey = generateTestPreKey(1); + final ECPreKey preKey = generateTestPreKey(1); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); - final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); + final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); assertEquals(Optional.of(preKey), takenKey); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); } @@ -114,9 +115,9 @@ class KeysManagerTest { void testTakePQ() { assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); - final SignedPreKey preKey1 = generateTestSignedPreKey(1); - final SignedPreKey preKey2 = generateTestSignedPreKey(2); - final SignedPreKey preKeyLast = generateTestSignedPreKey(1001); + final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1); + final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2); + final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); @@ -138,7 +139,7 @@ class KeysManagerTest { assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null); + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); } @@ -147,13 +148,11 @@ class KeysManagerTest { void testDeleteByAccount() { keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), - generateTestSignedPreKey(5)); + List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5)); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(generateTestPreKey(6)), - List.of(generateTestSignedPreKey(7)), - generateTestSignedPreKey(8)); + List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); @@ -176,13 +175,11 @@ class KeysManagerTest { void testDeleteByAccountAndDevice() { keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), - generateTestSignedPreKey(5)); + List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5)); keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(generateTestPreKey(6)), - List.of(generateTestSignedPreKey(7)), - generateTestSignedPreKey(8)); + List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); @@ -211,17 +208,17 @@ class KeysManagerTest { ACCOUNT_UUID, Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); - assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId()); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId()); + assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId()); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent()); keysManager.storePqLastResort( ACCOUNT_UUID, Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates"); - assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone"); - assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone"); + assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones"); } @Test @@ -237,21 +234,15 @@ class KeysManagerTest { Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID))); } - private static PreKey generateTestPreKey(final long keyId) { - final byte[] key = new byte[32]; - new SecureRandom().nextBytes(key); - - return new PreKey(keyId, key); + private static ECPreKey generateTestPreKey(final long keyId) { + return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); } - private static SignedPreKey generateTestSignedPreKey(final long keyId) { - final byte[] key = new byte[32]; - final byte[] signature = new byte[32]; + private static ECSignedPreKey generateTestECSignedPreKey(final long keyId) { + return KeysHelper.signedECPreKey(keyId, IDENTITY_KEY_PAIR); + } - final SecureRandom secureRandom = new SecureRandom(); - secureRandom.nextBytes(key); - secureRandom.nextBytes(signature); - - return new SignedPreKey(keyId, key, signature); + private static KEMSignedPreKey generateTestKEMSignedPreKey(final long keyId) { + return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java new file mode 100644 index 000000000..e7ed8205b --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; + +import static org.junit.jupiter.api.Assertions.*; + +class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest { + + private RepeatedUseKEMSignedPreKeyStore keyStore; + + private int currentKeyId = 1; + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = + new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS); + + private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + + @BeforeEach + void setUp() { + keyStore = new RepeatedUseKEMSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); + } + + @Override + protected RepeatedUseSignedPreKeyStore getKeyStore() { + return keyStore; + } + + @Override + protected KEMSignedPreKey generateSignedPreKey() { + return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 40a4757d8..456b24450 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -5,59 +5,31 @@ package org.whispersystems.textsecuregcm.storage; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.reactivestreams.Subscriber; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.tests.util.KeysHelper; -import org.whispersystems.textsecuregcm.util.AttributeValues; -import reactor.core.publisher.Flux; -import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; -import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; -import software.amazon.awssdk.services.dynamodb.model.QueryRequest; -import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.Map; import java.util.Optional; import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +abstract class RepeatedUseSignedPreKeyStoreTest> { -class RepeatedUseSignedPreKeyStoreTest { + protected abstract RepeatedUseSignedPreKeyStore getKeyStore(); - private RepeatedUseSignedPreKeyStore keys; - - private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); - - @RegisterExtension - static final DynamoDbExtension DYNAMO_DB_EXTENSION = - new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS); - - @BeforeEach - void setUp() { - keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), - DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); - } + protected abstract K generateSignedPreKey(); @Test void storeFind() { + final RepeatedUseSignedPreKeyStore keys = getKeyStore(); + assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join()); { final UUID identifier = UUID.randomUUID(); final long deviceId = 1; - final SignedPreKey signedPreKey = generateSignedPreKey(); + final K signedPreKey = generateSignedPreKey(); assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join()); @@ -65,7 +37,7 @@ class RepeatedUseSignedPreKeyStoreTest { { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( + final Map signedPreKeys = Map.of( 1L, generateSignedPreKey(), 2L, generateSignedPreKey() ); @@ -78,11 +50,13 @@ class RepeatedUseSignedPreKeyStoreTest { @Test void delete() { + final RepeatedUseSignedPreKeyStore keys = getKeyStore(); + assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( + final Map signedPreKeys = Map.of( 1L, generateSignedPreKey(), 2L, generateSignedPreKey() ); @@ -96,7 +70,7 @@ class RepeatedUseSignedPreKeyStoreTest { { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( + final Map signedPreKeys = Map.of( 1L, generateSignedPreKey(), 2L, generateSignedPreKey() ); @@ -108,42 +82,4 @@ class RepeatedUseSignedPreKeyStoreTest { assertEquals(Optional.empty(), keys.find(identifier, 2).join()); } } - - @Test - void deleteWithError() { - final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class); - final QueryPublisher queryPublisher = mock(QueryPublisher.class); - - final SdkPublisher> itemPublisher = new SdkPublisher>() { - final Flux> items = Flux.just( - Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)), - Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2))); - - @Override - public void subscribe(final Subscriber> subscriber) { - items.subscribe(subscriber); - } - }; - - when(queryPublisher.items()).thenReturn(itemPublisher); - when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher); - - final Exception deleteItemException = new IllegalArgumentException("OH NO"); - - when(mockClient.deleteItem(any(DeleteItemRequest.class))) - .thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build())) - .thenReturn(CompletableFuture.failedFuture(deleteItemException)); - - final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient, - DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); - - final CompletionException completionException = - assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join()); - - assertEquals(deleteItemException, completionException.getCause()); - } - - private static SignedPreKey generateSignedPreKey() { - return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java index ca6b14654..52044e4f9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java @@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.storage; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; -import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; -class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { +class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { private SingleUseECPreKeyStore preKeyStore; @@ -24,12 +24,12 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { } @Override - protected SingleUsePreKeyStore getPreKeyStore() { + protected SingleUsePreKeyStore getPreKeyStore() { return preKeyStore; } @Override - protected PreKey generatePreKey(final long keyId) { - return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize()); + protected ECPreKey generatePreKey(final long keyId) { + return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java index b0df5cd31..e21685f3c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java @@ -9,10 +9,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; -class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest { +class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest { private SingleUseKEMPreKeyStore preKeyStore; @@ -28,12 +28,12 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest } @Override - protected SingleUsePreKeyStore getPreKeyStore() { + protected SingleUsePreKeyStore getPreKeyStore() { return preKeyStore; } @Override - protected SignedPreKey generatePreKey(final long keyId) { + protected KEMSignedPreKey generatePreKey(final long keyId) { return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java index dd8563ceb..894be6367 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -24,7 +24,7 @@ import org.whispersystems.textsecuregcm.entities.PreKey; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -abstract class SingleUsePreKeyStoreTest { +abstract class SingleUsePreKeyStoreTest> { private static final int KEY_COUNT = 100; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index a693ff785..3d6e47b53 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -54,9 +54,10 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceResponse; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; @@ -265,10 +266,10 @@ class DeviceControllerTest { assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); - final Optional aciSignedPreKey; - final Optional pniSignedPreKey; - final Optional aciPqLastResortPreKey; - final Optional pniPqLastResortPreKey; + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); @@ -351,10 +352,10 @@ class DeviceControllerTest { assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); - final Optional aciSignedPreKey; - final Optional pniSignedPreKey; - final Optional aciPqLastResortPreKey; - final Optional pniPqLastResortPreKey; + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); @@ -395,10 +396,10 @@ class DeviceControllerTest { @SuppressWarnings("OptionalUsedAsFieldOrParameterType") void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey, final IdentityKey pniIdentityKey, - final Optional aciSignedPreKey, - final Optional pniSignedPreKey, - final Optional aciPqLastResortPreKey, - final Optional pniPqLastResortPreKey) { + final Optional aciSignedPreKey, + final Optional pniSignedPreKey, + final Optional aciPqLastResortPreKey, + final Optional pniPqLastResortPreKey) { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); @@ -435,10 +436,10 @@ class DeviceControllerTest { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final Optional aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); - final Optional pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Optional aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); - final Optional pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + final Optional aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); + final Optional pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Optional aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); + final Optional pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); @@ -455,10 +456,10 @@ class DeviceControllerTest { @MethodSource void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey, final IdentityKey pniIdentityKey, - final SignedPreKey aciSignedPreKey, - final SignedPreKey pniSignedPreKey, - final SignedPreKey aciPqLastResortPreKey, - final SignedPreKey pniPqLastResortPreKey) { + final ECSignedPreKey aciSignedPreKey, + final ECSignedPreKey pniSignedPreKey, + final KEMSignedPreKey aciPqLastResortPreKey, + final KEMSignedPreKey pniPqLastResortPreKey) { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); @@ -495,25 +496,31 @@ class DeviceControllerTest { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final SignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); - final SignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); - final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); - final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); return Stream.of( - Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), - Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, signedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey), - Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, signedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey), - Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, signedPreKeyWithBadSignature(pniPqLastResortPreKey)) + Arguments.of(aciIdentityKey, pniIdentityKey, ecSignedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, ecSignedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, kemSignedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey), + Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, kemSignedPreKeyWithBadSignature(pniPqLastResortPreKey)) ); } - private static SignedPreKey signedPreKeyWithBadSignature(final SignedPreKey signedPreKey) { - return new SignedPreKey(signedPreKey.getKeyId(), - signedPreKey.getPublicKey(), + private static ECSignedPreKey ecSignedPreKeyWithBadSignature(final ECSignedPreKey signedPreKey) { + return new ECSignedPreKey(signedPreKey.keyId(), + signedPreKey.publicKey(), + "incorrect-signature".getBytes(StandardCharsets.UTF_8)); + } + + private static KEMSignedPreKey kemSignedPreKeyWithBadSignature(final KEMSignedPreKey signedPreKey) { + return new KEMSignedPreKey(signedPreKey.keyId(), + signedPreKey.publicKey(), "incorrect-signature".getBytes(StandardCharsets.UTF_8)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index edbce20b4..435922cda 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -20,6 +21,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -47,6 +50,9 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyResponse; @@ -63,6 +69,7 @@ import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; @ExtendWith(DropwizardExtensionsSupport.class) class KeysControllerTest { @@ -86,27 +93,27 @@ class KeysControllerTest { private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()); - private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); - private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); - private final PreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334); - private final PreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336); + private final ECPreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); + private final ECPreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); + private final ECPreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334); + private final ECPreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336); - private final PreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777); + private final ECPreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777); - private final SignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair()); - private final SignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair()); - private final SignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair()); + private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair()); + private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair()); + private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair()); - private final SignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair()); + private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair()); - private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR); - private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR); - private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR); + private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR); + private final ECSignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR); + private final ECSignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR); private final static KeysManager KEYS = mock(KeysManager.class ); private final static AccountsManager accounts = mock(AccountsManager.class ); @@ -127,6 +134,42 @@ class KeysControllerTest { private Device sampleDevice; + private record WeaklyTypedPreKey(long keyId, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] publicKey) { + + static WeaklyTypedPreKey fromPreKey(final PreKey preKey) { + return new WeaklyTypedPreKey(preKey.keyId(), preKey.serializedPublicKey()); + } + } + + private record WeaklyTypedSignedPreKey(long keyId, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] publicKey, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] signature) { + + static WeaklyTypedSignedPreKey fromSignedPreKey(final SignedPreKey signedPreKey) { + return new WeaklyTypedSignedPreKey(signedPreKey.keyId(), signedPreKey.serializedPublicKey(), signedPreKey.signature()); + } + } + + private record WeaklyTypedPreKeyState(List preKeys, + WeaklyTypedSignedPreKey signedPreKey, + List pqPreKeys, + WeaklyTypedSignedPreKey pqLastResortPreKey, + + @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + byte[] identityKey) { + } + @BeforeEach void setup() { sampleDevice = mock(Device.class); @@ -228,30 +271,30 @@ class KeysControllerTest { @Test void getSignedPreKeyV2() { - SignedPreKey result = resources.getJerseyTest() + ECSignedPreKey result = resources.getJerseyTest() .target("/v2/keys/signed") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(SignedPreKey.class); + .get(ECSignedPreKey.class); - assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result); + assertEquals(VALID_DEVICE_SIGNED_KEY, result); } @Test void getPhoneNumberIdentifierSignedPreKeyV2() { - SignedPreKey result = resources.getJerseyTest() + ECSignedPreKey result = resources.getJerseyTest() .target("/v2/keys/signed") .queryParam("identity", "pni") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(SignedPreKey.class); + .get(ECSignedPreKey.class); - assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result); + assertEquals(VALID_DEVICE_PNI_SIGNED_KEY, result); } @Test void putSignedPreKeyV2() { - SignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR); + ECSignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR); Response response = resources.getJerseyTest() .target("/v2/keys/signed") .request() @@ -267,7 +310,7 @@ class KeysControllerTest { @Test void putPhoneNumberIdentitySignedPreKeyV2() { - final SignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR); + final ECSignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR); Response response = resources.getJerseyTest() .target("/v2/keys/signed") @@ -285,7 +328,7 @@ class KeysControllerTest { @Test void disabledPutSignedPreKeyV2() { - SignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR); + ECSignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR); Response response = resources.getJerseyTest() .target("/v2/keys/signed") .request() @@ -305,10 +348,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); @@ -316,7 +359,7 @@ class KeysControllerTest { @Test void validSingleRequestPqTestNoPqKeysV2() { - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty()); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty()); PreKeyResponse result = resources.getJerseyTest() .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) @@ -327,10 +370,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); @@ -348,10 +391,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); @@ -368,10 +411,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); @@ -388,10 +431,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, 1); verify(KEYS).takePQ(EXISTS_PNI, 1); @@ -410,10 +453,10 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); @@ -445,9 +488,9 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); - assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); + assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1); @@ -510,14 +553,14 @@ class KeysControllerTest { assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); - PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); - PreKey preKey = results.getDevice(1).getPreKey(); + ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey(); + ECPreKey preKey = results.getDevice(1).getPreKey(); long registrationId = results.getDevice(1).getRegistrationId(); long deviceId = results.getDevice(1).getDeviceId(); - assertKeysMatch(SAMPLE_KEY, preKey); + assertEquals(SAMPLE_KEY, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); + assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); assertThat(deviceId).isEqualTo(1); signedPreKey = results.getDevice(2).getSignedPreKey(); @@ -525,9 +568,9 @@ class KeysControllerTest { registrationId = results.getDevice(2).getRegistrationId(); deviceId = results.getDevice(2).getDeviceId(); - assertKeysMatch(SAMPLE_KEY2, preKey); + assertEquals(SAMPLE_KEY2, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); - assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); + assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); assertThat(deviceId).isEqualTo(2); signedPreKey = results.getDevice(4).getSignedPreKey(); @@ -535,7 +578,7 @@ class KeysControllerTest { registrationId = results.getDevice(4).getRegistrationId(); deviceId = results.getDevice(4).getDeviceId(); - assertKeysMatch(SAMPLE_KEY4, preKey); + assertEquals(SAMPLE_KEY4, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(signedPreKey).isNull(); assertThat(deviceId).isEqualTo(4); @@ -554,7 +597,7 @@ class KeysControllerTest { when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2)); when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3)); - when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty()); + when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty()); PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) @@ -566,16 +609,16 @@ class KeysControllerTest { assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); - PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); - PreKey preKey = results.getDevice(1).getPreKey(); - SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey(); + ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey(); + ECPreKey preKey = results.getDevice(1).getPreKey(); + KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey(); long registrationId = results.getDevice(1).getRegistrationId(); long deviceId = results.getDevice(1).getDeviceId(); - assertKeysMatch(SAMPLE_KEY, preKey); - assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey); + assertEquals(SAMPLE_KEY, preKey); + assertEquals(SAMPLE_PQ_KEY, pqPreKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); - assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); + assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); assertThat(deviceId).isEqualTo(1); signedPreKey = results.getDevice(2).getSignedPreKey(); @@ -585,9 +628,9 @@ class KeysControllerTest { deviceId = results.getDevice(2).getDeviceId(); assertThat(preKey).isNull(); - assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey); + assertEquals(SAMPLE_PQ_KEY2, pqPreKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); - assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); + assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); assertThat(deviceId).isEqualTo(2); signedPreKey = results.getDevice(4).getSignedPreKey(); @@ -596,7 +639,7 @@ class KeysControllerTest { registrationId = results.getDevice(4).getRegistrationId(); deviceId = results.getDevice(4).getDeviceId(); - assertKeysMatch(SAMPLE_KEY4, preKey); + assertEquals(SAMPLE_KEY4, preKey); assertThat(pqPreKey).isNull(); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(signedPreKey).isNull(); @@ -656,9 +699,9 @@ class KeysControllerTest { @Test void putKeysTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); @@ -672,7 +715,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); assertThat(listCaptor.getValue()).containsExactly(preKey); @@ -684,11 +727,11 @@ class KeysControllerTest { @Test void putKeysPqTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); - final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); - final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); + final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); + final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); @@ -702,8 +745,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); - ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); @@ -718,8 +761,9 @@ class KeysControllerTest { void putKeysStructurallyInvalidSignedECKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair); - final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null); + final KEMSignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair); + final WeaklyTypedPreKeyState preKeyState = + new WeaklyTypedPreKeyState(null, WeaklyTypedSignedPreKey.fromSignedPreKey(wrongPreKey), null, null, identityKey.serialize()); Response response = resources.getJerseyTest() @@ -728,15 +772,16 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(422); + assertThat(response.getStatus()).isEqualTo(400); } @Test void putKeysStructurallyInvalidUnsignedECKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes()); - final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null); + final WeaklyTypedPreKey wrongPreKey = new WeaklyTypedPreKey(1, "cluck cluck i'm a parrot".getBytes()); + final WeaklyTypedPreKeyState preKeyState = + new WeaklyTypedPreKeyState(List.of(wrongPreKey), null, null, null, identityKey.serialize()); Response response = resources.getJerseyTest() @@ -745,15 +790,16 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(422); + assertThat(response.getStatus()).isEqualTo(400); } @Test void putKeysStructurallyInvalidPQOneTimeKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); - final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null); + final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair)); + final WeaklyTypedPreKeyState preKeyState = + new WeaklyTypedPreKeyState(null, null, List.of(wrongPreKey), null, identityKey.serialize()); Response response = resources.getJerseyTest() @@ -762,15 +808,16 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(422); + assertThat(response.getStatus()).isEqualTo(400); } @Test void putKeysStructurallyInvalidPQLastResortKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); - final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); - final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey); + final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair)); + final WeaklyTypedPreKeyState preKeyState = + new WeaklyTypedPreKeyState(null, null, null, wrongPreKey, identityKey.serialize()); Response response = resources.getJerseyTest() @@ -779,14 +826,14 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); - assertThat(response.getStatus()).isEqualTo(422); + assertThat(response.getStatus()).isEqualTo(400); } @Test void putKeysByPhoneNumberIdentifierTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); @@ -801,7 +848,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull()); assertThat(listCaptor.getValue()).containsExactly(preKey); @@ -813,11 +860,11 @@ class KeysControllerTest { @Test void putKeysByPhoneNumberIdentifierPqTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); - final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); - final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); + final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); + final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); @@ -832,8 +879,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); - ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); @@ -846,7 +893,7 @@ class KeysControllerTest { @Test void putPrekeyWithInvalidSignature() { - final SignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair()); + final ECSignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair()); PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, badSignedPreKey, List.of()); Response response = resources.getJerseyTest() @@ -861,9 +908,9 @@ class KeysControllerTest { @Test void disabledPutKeysTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); @@ -877,13 +924,13 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); - List capturedList = listCaptor.getValue(); + List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); - assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); - assertThat(capturedList.get(0).getPublicKey()).isEqualTo(preKey.getPublicKey()); + assertThat(capturedList.get(0).keyId()).isEqualTo(31337); + assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey()); verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); @@ -892,10 +939,10 @@ class KeysControllerTest { @Test void putIdentityKeyNonPrimary() { - final PreKey preKey = KeysHelper.ecPreKey(31337); - final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR); + final ECPreKey preKey = KeysHelper.ecPreKey(31337); + final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR); - List preKeys = List.of(preKey); + List preKeys = List.of(preKey); PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, signedPreKey, preKeys); @@ -908,13 +955,4 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(403); } - - private void assertKeysMatch(PreKey expected, PreKey actual) { - assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId()); - assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey()); - if (expected instanceof final SignedPreKey signedExpected) { - final SignedPreKey signedActual = (SignedPreKey) actual; - assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature()); - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/entities/PreKeyTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/entities/PreKeyTest.java index 87c3ea109..1dfeee56d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/entities/PreKeyTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/entities/PreKeyTest.java @@ -12,7 +12,8 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.entities.PreKey; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; import java.util.Base64; @@ -22,7 +23,7 @@ class PreKeyTest { @Test void serializeToJSONV2() throws Exception { - PreKey preKey = new PreKey(1234, PUBLIC_KEY); + ECPreKey preKey = new ECPreKey(1234, new ECPublicKey(PUBLIC_KEY)); assertThat("PreKeyV2 Serialization works", asJson(preKey), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java index 86601867c..4ee2ef2ad 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java @@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.util; import java.util.Random; import org.signal.libsignal.protocol.ecc.Curve; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Util; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/KeysHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/KeysHelper.java index bd9450be9..1217fbb02 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/KeysHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/KeysHelper.java @@ -7,26 +7,29 @@ package org.whispersystems.textsecuregcm.tests.util; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.kem.KEMKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyType; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.signal.libsignal.protocol.kem.KEMPublicKey; +import org.whispersystems.textsecuregcm.entities.ECPreKey; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; public final class KeysHelper { - public static PreKey ecPreKey(final long id) { - return new PreKey(id, Curve.generateKeyPair().getPublicKey().serialize()); + public static ECPreKey ecPreKey(final long id) { + return new ECPreKey(id, Curve.generateKeyPair().getPublicKey()); } - public static SignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) { - final byte[] pubKey = Curve.generateKeyPair().getPublicKey().serialize(); - final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); - return new SignedPreKey(id, pubKey, sig); + public static ECSignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) { + final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey(); + final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize()); + return new ECSignedPreKey(id, pubKey, sig); } - public static SignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) { - final byte[] pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey().serialize(); - final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); - return new SignedPreKey(id, pubKey, sig); + public static KEMSignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) { + final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey(); + final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize()); + return new KEMSignedPreKey(id, pubKey, sig); } }