Use strongly-typed pre-keys

This commit is contained in:
Jon Chambers 2023-06-09 10:08:49 -04:00 committed by GitHub
parent b27334b0ff
commit 17aa5d8e74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 805 additions and 719 deletions

View File

@ -32,15 +32,18 @@ import org.signal.integration.config.Config;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.kem.KEMKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyType; import org.signal.libsignal.protocol.kem.KEMKeyType;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient; import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
@ -325,15 +328,15 @@ public final class Operations {
} }
} }
private static SignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) { private static ECSignedPreKey generateSignedECPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = Curve.generateKeyPair().getPublicKey().serialize(); final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new SignedPreKey(id, pubKey, sig); return new ECSignedPreKey(id, pubKey, sig);
} }
private static SignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) { private static KEMSignedPreKey generateSignedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey().serialize(); final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new SignedPreKey(id, pubKey, sig); return new KEMSignedPreKey(id, pubKey, sig);
} }
} }

View File

@ -41,7 +41,9 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem; import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem;
@ -207,9 +209,9 @@ public class KeysController {
for (Device device : devices) { for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); ECSignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); ECPreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ? final int registrationId = usePhoneNumberIdentity ?
@ -234,7 +236,7 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState @ChangesDeviceEnabledState
public void setSignedKey(@Auth final AuthenticatedAccount auth, public void setSignedKey(@Auth final AuthenticatedAccount auth,
@Valid final SignedPreKey signedPreKey, @Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") final Optional<String> identityType) {
Device device = auth.getAuthenticatedDevice(); Device device = auth.getAuthenticatedDevice();
@ -252,11 +254,11 @@ public class KeysController {
@GET @GET
@Path("/signed") @Path("/signed")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Optional<SignedPreKey> getSignedKey(@Auth final AuthenticatedAccount auth, public Optional<ECSignedPreKey> getSignedKey(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") final Optional<String> identityType) {
Device device = auth.getAuthenticatedDevice(); Device device = auth.getAuthenticatedDevice();
SignedPreKey signedPreKey = usePhoneNumberIdentity(identityType) ? ECSignedPreKey signedPreKey = usePhoneNumberIdentity(identityType) ?
device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
return Optional.ofNullable(signedPreKey); return Optional.ofNullable(signedPreKey);

View File

@ -20,8 +20,6 @@ import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record ChangeNumberRequest( public record ChangeNumberRequest(
@Schema(description=""" @Schema(description="""
@ -54,7 +52,7 @@ public record ChangeNumberRequest(
@Schema(description=""" @Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one. A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Long, @NotNull @Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> devicePniSignedPrekeys, @NotNull @Valid Map<Long, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description=""" @Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
@ -62,14 +60,14 @@ public record ChangeNumberRequest(
If present, must contain one prekey per enabled device including this one. If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> devicePniPqLastResortPrekeys, @Valid Map<Long, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest { @NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
List<SignedPreKey> spks = new ArrayList<>(); List<SignedPreKey<?>> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) { if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values()); spks.addAll(devicePniSignedPrekeys.values());
} }

View File

@ -19,8 +19,6 @@ import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record ChangePhoneNumberRequest( public record ChangePhoneNumberRequest(
@Schema(description="the new phone number for this account") @Schema(description="the new phone number for this account")
@ -46,7 +44,7 @@ public record ChangePhoneNumberRequest(
@Schema(description=""" @Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one. A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable Map<Long, @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> devicePniSignedPrekeys, @Nullable Map<Long, ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description=""" @Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
@ -54,14 +52,14 @@ public record ChangePhoneNumberRequest(
If present, must contain one prekey per enabled device including this one. If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable @Valid Map<Long, @NotNull @Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> devicePniPqLastResortPrekeys, @Nullable @Valid Map<Long, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@Nullable Map<Long, Integer> pniRegistrationIds) { @Nullable Map<Long, Integer> pniRegistrationIds) {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
List<SignedPreKey> spks = new ArrayList<>(); List<SignedPreKey<?>> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) { if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values()); spks.addAll(devicePniSignedPrekeys.values());
} }

View File

@ -4,9 +4,6 @@ import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.Valid; import javax.validation.Valid;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
import java.util.Optional; import java.util.Optional;
public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
@ -14,28 +11,28 @@ public record DeviceActivationRequest(@Schema(requiredMode = Schema.RequiredMode
will be created "atomically," and all other properties needed for atomic account will be created "atomically," and all other properties needed for atomic account
creation must also be present. creation must also be present.
""") """)
Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
A signed EC pre-key to be associated with this account's PNI. If provided, an account A signed EC pre-key to be associated with this account's PNI. If provided, an account
will be created "atomically," and all other properties needed for atomic account will be created "atomically," and all other properties needed for atomic account
creation must also be present. creation must also be present.
""") """)
Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
A signed Kyber-1024 "last resort" pre-key to be associated with this account's ACI. If A signed Kyber-1024 "last resort" pre-key to be associated with this account's ACI. If
provided, an account will be created "atomically," and all other properties needed for provided, an account will be created "atomically," and all other properties needed for
atomic account creation must also be present. atomic account creation must also be present.
""") """)
Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciPqLastResortPreKey, Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
A signed Kyber-1024 "last resort" pre-key to be associated with this account's PNI. If A signed Kyber-1024 "last resort" pre-key to be associated with this account's PNI. If
provided, an account will be created "atomically," and all other properties needed for provided, an account will be created "atomically," and all other properties needed for
atomic account creation must also be present. atomic account creation must also be present.
""") """)
Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniPqLastResortPreKey, Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
An APNs token set for the account's primary device. If provided, the account's primary An APNs token set for the account's primary device. If provided, the account's primary

View File

@ -0,0 +1,22 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
public record ECPreKey(long keyId,
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)
@JsonDeserialize(using = ECPublicKeyAdapter.Deserializer.class)
ECPublicKey publicKey) implements PreKey<ECPublicKey> {
@Override
public byte[] serializedPublicKey() {
return publicKey().serialize();
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
import java.util.Arrays;
import java.util.Objects;
public record ECSignedPreKey(long keyId,
@JsonSerialize(using = ECPublicKeyAdapter.Serializer.class)
@JsonDeserialize(using = ECPublicKeyAdapter.Deserializer.class)
ECPublicKey publicKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] signature) implements SignedPreKey<ECPublicKey> {
@Override
public byte[] serializedPublicKey() {
return publicKey().serialize();
}
@Override
public boolean equals(final Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
ECSignedPreKey that = (ECSignedPreKey) o;
return keyId == that.keyId && publicKey.equals(that.publicKey) && Arrays.equals(signature, that.signature);
}
@Override
public int hashCode() {
int result = Objects.hash(keyId, publicKey);
result = 31 * result + Arrays.hashCode(signature);
return result;
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.KEMPublicKeyAdapter;
import java.util.Arrays;
import java.util.Objects;
public record KEMSignedPreKey(long keyId,
@JsonSerialize(using = KEMPublicKeyAdapter.Serializer.class)
@JsonDeserialize(using = KEMPublicKeyAdapter.Deserializer.class)
KEMPublicKey publicKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] signature) implements SignedPreKey<KEMPublicKey> {
@Override
public byte[] serializedPublicKey() {
return publicKey().serialize();
}
@Override
public boolean equals(final Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
KEMSignedPreKey that = (KEMSignedPreKey) o;
return keyId == that.keyId && publicKey.equals(that.publicKey) && Arrays.equals(signature, that.signature);
}
@Override
public int hashCode() {
int result = Objects.hash(keyId, publicKey);
result = 31 * result + Arrays.hashCode(signature);
return result;
}
}

View File

@ -25,10 +25,10 @@ public record LinkDeviceRequest(@Schema(requiredMode = Schema.RequiredMode.REQUI
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode, public LinkDeviceRequest(@JsonProperty("verificationCode") String verificationCode,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("aciSignedPreKey") Optional<@Valid SignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid SignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid SignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@JsonProperty("pniPqLastResortPreKey") Optional<@Valid SignedPreKey> pniPqLastResortPreKey, @JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey,
@JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken,
@JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) {

View File

@ -15,8 +15,6 @@ import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record PhoneNumberIdentityKeyDistributionRequest( public record PhoneNumberIdentityKeyDistributionRequest(
@NotNull @NotNull
@ -37,7 +35,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
@Schema(description=""" @Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one. A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Long, @NotNull @Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> devicePniSignedPrekeys, Map<Long, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description=""" @Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
@ -45,7 +43,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
If present, must contain one prekey per enabled device including this one. If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""") Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> devicePniPqLastResortPrekeys, @Valid Map<Long, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@NotNull @NotNull
@Valid @Valid
@ -54,7 +52,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
List<SignedPreKey> spks = new ArrayList<>(devicePniSignedPrekeys.values()); List<SignedPreKey<?>> spks = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniPqLastResortPrekeys != null) { if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values()); spks.addAll(devicePniPqLastResortPrekeys.values());
} }

View File

@ -1,68 +1,15 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; public interface PreKey<K> {
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import javax.validation.constraints.NotEmpty; long keyId();
import javax.validation.constraints.NotNull;
import java.util.Arrays;
import java.util.Objects;
public class PreKey { K publicKey();
@JsonProperty byte[] serializedPublicKey();
@NotNull
private long keyId;
@JsonProperty
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
@NotEmpty
private byte[] publicKey;
public PreKey() {}
public PreKey(long keyId, byte[] publicKey)
{
this.keyId = keyId;
this.publicKey = publicKey;
}
public byte[] getPublicKey() {
return publicKey;
}
public void setPublicKey(byte[] publicKey) {
this.publicKey = publicKey;
}
public long getKeyId() {
return keyId;
}
public void setKeyId(long keyId) {
this.keyId = keyId;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PreKey preKey = (PreKey) o;
return keyId == preKey.keyId && Arrays.equals(publicKey, preKey.publicKey);
}
@Override
public int hashCode() {
int result = Objects.hash(keyId);
result = 31 * result + Arrays.hashCode(publicKey);
return result;
}
} }

View File

@ -20,20 +20,20 @@ public class PreKeyResponseItem {
@JsonProperty @JsonProperty
@Schema(description="the signed elliptic-curve prekey for the device, if one has been set") @Schema(description="the signed elliptic-curve prekey for the device, if one has been set")
private SignedPreKey signedPreKey; private ECSignedPreKey signedPreKey;
@JsonProperty @JsonProperty
@Schema(description="an unsigned elliptic-curve prekey for the device, if any remain") @Schema(description="an unsigned elliptic-curve prekey for the device, if any remain")
private PreKey preKey; private ECPreKey preKey;
@JsonProperty @JsonProperty
@Schema(description="a signed post-quantum prekey for the device " + @Schema(description="a signed post-quantum prekey for the device " +
"(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)") "(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)")
private SignedPreKey pqPreKey; private KEMSignedPreKey pqPreKey;
public PreKeyResponseItem() {} public PreKeyResponseItem() {}
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) { public PreKeyResponseItem(long deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey, KEMSignedPreKey pqPreKey) {
this.deviceId = deviceId; this.deviceId = deviceId;
this.registrationId = registrationId; this.registrationId = registrationId;
this.signedPreKey = signedPreKey; this.signedPreKey = signedPreKey;
@ -42,17 +42,17 @@ public class PreKeyResponseItem {
} }
@VisibleForTesting @VisibleForTesting
public SignedPreKey getSignedPreKey() { public ECSignedPreKey getSignedPreKey() {
return signedPreKey; return signedPreKey;
} }
@VisibleForTesting @VisibleForTesting
public PreKey getPreKey() { public ECPreKey getPreKey() {
return preKey; return preKey;
} }
@VisibleForTesting @VisibleForTesting
public SignedPreKey getPqPreKey() { public KEMSignedPreKey getPqPreKey() {
return pqPreKey; return pqPreKey;
} }

View File

@ -15,19 +15,13 @@ public abstract class PreKeySignatureValidator {
public static final Counter INVALID_SIGNATURE_COUNTER = public static final Counter INVALID_SIGNATURE_COUNTER =
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")); Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature"));
public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection<SignedPreKey> spks) { public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection<SignedPreKey<?>> spks) {
try { final boolean success = spks.stream().allMatch(spk -> spk.signatureValid(identityKey));
final boolean success = spks.stream()
.allMatch(spk -> identityKey.getPublicKey().verifySignature(spk.getPublicKey(), spk.getSignature()));
if (!success) { if (!success) {
INVALID_SIGNATURE_COUNTER.increment();
}
return success;
} catch (final IllegalArgumentException e) {
INVALID_SIGNATURE_COUNTER.increment(); INVALID_SIGNATURE_COUNTER.increment();
return false;
} }
return success;
} }
} }

View File

@ -16,8 +16,6 @@ import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public class PreKeyState { public class PreKeyState {
@ -26,16 +24,15 @@ public class PreKeyState {
@Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " + @Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " +
"If present and not empty, replaces all stored unsigned EC prekeys for the device; " + "If present and not empty, replaces all stored unsigned EC prekeys for the device; " +
"if absent or empty, any stored unsigned EC prekeys for the device are not deleted.") "if absent or empty, any stored unsigned EC prekeys for the device are not deleted.")
private List<@ValidPreKey(type=PreKeyType.ECC) PreKey> preKeys; private List<@Valid ECPreKey> preKeys;
@JsonProperty @JsonProperty
@Valid @Valid
@ValidPreKey(type=PreKeyType.ECC)
@Schema(description="An optional signed elliptic-curve prekey to use for this device. " + @Schema(description="An optional signed elliptic-curve prekey to use for this device. " +
"If present, replaces the stored signed elliptic-curve prekey for the device; " + "If present, replaces the stored signed elliptic-curve prekey for the device; " +
"if absent, the stored signed prekey is not deleted. " + "if absent, the stored signed prekey is not deleted. " +
"If present, must have a valid signature from the identity key in this request.") "If present, must have a valid signature from the identity key in this request.")
private SignedPreKey signedPreKey; private ECSignedPreKey signedPreKey;
@JsonProperty @JsonProperty
@Valid @Valid
@ -43,16 +40,15 @@ public class PreKeyState {
"Each key must have a valid signature from the identity key in this request. " + "Each key must have a valid signature from the identity key in this request. " +
"If present and not empty, replaces all stored unsigned PQ prekeys for the device; " + "If present and not empty, replaces all stored unsigned PQ prekeys for the device; " +
"if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.") "if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.")
private List<@ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> pqPreKeys; private List<@Valid KEMSignedPreKey> pqPreKeys;
@JsonProperty @JsonProperty
@Valid @Valid
@ValidPreKey(type=PreKeyType.KYBER)
@Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " + @Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " +
"If present, replaces the stored signed post-quantum last-resort prekey for the device; " + "If present, replaces the stored signed post-quantum last-resort prekey for the device; " +
"if absent, a stored last-resort prekey will *not* be deleted. " + "if absent, a stored last-resort prekey will *not* be deleted. " +
"If present, must have a valid signature from the identity key in this request.") "If present, must have a valid signature from the identity key in this request.")
private SignedPreKey pqLastResortPreKey; private KEMSignedPreKey pqLastResortPreKey;
@JsonProperty @JsonProperty
@JsonSerialize(using = IdentityKeyAdapter.Serializer.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@ -67,12 +63,12 @@ public class PreKeyState {
public PreKeyState() {} public PreKeyState() {}
@VisibleForTesting @VisibleForTesting
public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List<PreKey> keys) { public PreKeyState(IdentityKey identityKey, ECSignedPreKey signedPreKey, List<ECPreKey> keys) {
this(identityKey, signedPreKey, keys, null, null); this(identityKey, signedPreKey, keys, null, null);
} }
@VisibleForTesting @VisibleForTesting
public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) { public PreKeyState(IdentityKey identityKey, ECSignedPreKey signedPreKey, List<ECPreKey> keys, List<KEMSignedPreKey> pqKeys, KEMSignedPreKey pqLastResortKey) {
this.identityKey = identityKey; this.identityKey = identityKey;
this.signedPreKey = signedPreKey; this.signedPreKey = signedPreKey;
this.preKeys = keys; this.preKeys = keys;
@ -80,19 +76,19 @@ public class PreKeyState {
this.pqLastResortPreKey = pqLastResortKey; this.pqLastResortPreKey = pqLastResortKey;
} }
public List<PreKey> getPreKeys() { public List<ECPreKey> getPreKeys() {
return preKeys; return preKeys;
} }
public SignedPreKey getSignedPreKey() { public ECSignedPreKey getSignedPreKey() {
return signedPreKey; return signedPreKey;
} }
public List<SignedPreKey> getPqPreKeys() { public List<KEMSignedPreKey> getPqPreKeys() {
return pqPreKeys; return pqPreKeys;
} }
public SignedPreKey getPqLastResortPreKey() { public KEMSignedPreKey getPqLastResortPreKey() {
return pqLastResortPreKey; return pqLastResortPreKey;
} }
@ -102,7 +98,7 @@ public class PreKeyState {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedKey() { public boolean isSignatureValidOnEachSignedKey() {
List<SignedPreKey> spks = new ArrayList<>(); List<SignedPreKey<?>> spks = new ArrayList<>();
if (pqPreKeys != null) { if (pqPreKeys != null) {
spks.addAll(pqPreKeys); spks.addAll(pqPreKeys);
} }

View File

@ -20,8 +20,6 @@ import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ID of an existing verification session as it appears in a verification session The ID of an existing verification session as it appears in a verification session
@ -82,10 +80,10 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey, @JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey,
@JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey, @JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid ECSignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> aciPqLastResortPreKey,
@JsonProperty("pniPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> pniPqLastResortPreKey, @JsonProperty("pniPqLastResortPreKey") Optional<@Valid KEMSignedPreKey> pniPqLastResortPreKey,
@JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken, @JsonProperty("apnToken") Optional<@Valid ApnRegistrationId> apnToken,
@JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) { @JsonProperty("gcmToken") Optional<@Valid GcmRegistrationId> gcmToken) {
@ -106,7 +104,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private static boolean validatePreKeySignature(final Optional<IdentityKey> maybeIdentityKey, private static boolean validatePreKeySignature(final Optional<IdentityKey> maybeIdentityKey,
final Optional<SignedPreKey> maybeSignedPreKey) { final Optional<? extends SignedPreKey<?>> maybeSignedPreKey) {
return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey
.map(identityKey -> PreKeySignatureValidator.validatePreKeySignatures(identityKey, List.of(signedPreKey))) .map(identityKey -> PreKeySignatureValidator.validatePreKeySignatures(identityKey, List.of(signedPreKey)))

View File

@ -1,50 +1,17 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import org.signal.libsignal.protocol.IdentityKey;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import javax.validation.constraints.NotEmpty; public interface SignedPreKey<K> extends PreKey<K> {
import java.util.Arrays;
public class SignedPreKey extends PreKey { byte[] signature();
@JsonProperty default boolean signatureValid(final IdentityKey identityKey) {
@JsonSerialize(using = ByteArrayAdapter.Serializing.class) return identityKey.getPublicKey().verifySignature(serializedPublicKey(), signature());
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
@NotEmpty
private byte[] signature;
public SignedPreKey() {}
public SignedPreKey(long keyId, byte[] publicKey, byte[] signature) {
super(keyId, publicKey);
this.signature = signature;
}
public byte[] getSignature() {
return signature;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (!super.equals(o)) return false;
SignedPreKey that = (SignedPreKey) o;
return Arrays.equals(signature, that.signature);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(signature);
return result;
} }
} }

View File

@ -45,7 +45,8 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -258,8 +259,8 @@ public class AccountsManager {
public Account changeNumber(final Account account, public Account changeNumber(final Account account,
final String targetNumber, final String targetNumber,
@Nullable final IdentityKey pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys, @Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { @Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber(); final String originalNumber = account.getNumber();
@ -350,8 +351,8 @@ public class AccountsManager {
public Account updatePniKeys(final Account account, public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final Map<Long, SignedPreKey> pniSignedPreKeys, final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys, @Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException { final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
@ -369,7 +370,7 @@ public class AccountsManager {
private boolean setPniKeys(final Account account, private boolean setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) { @Nullable final Map<Long, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return false; return false;
@ -383,7 +384,7 @@ public class AccountsManager {
if (!device.isEnabled()) { if (!device.isEnabled()) {
continue; continue;
} }
SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId()); ECSignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId()); int registrationId = pniRegistrationIds.get(device.getId());
changed = changed || changed = changed ||
!signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) || !signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) ||
@ -398,8 +399,8 @@ public class AccountsManager {
} }
private void validateDevices(final Account account, private void validateDevices(final Account account,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys, @Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException { @Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) { if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return; return;

View File

@ -20,9 +20,10 @@ import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
@ -41,8 +42,8 @@ public class ChangeNumberManager {
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@Nullable final IdentityKey pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys, @Nullable final Map<Long, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys, @Nullable final Map<Long, KEMSignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages, @Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds) @Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException { throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
@ -81,8 +82,8 @@ public class ChangeNumberManager {
public Account updatePniKeys(final Account account, public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final Map<Long, SignedPreKey> deviceSignedPreKeys, final Map<Long, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys, @Nullable final Map<Long, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages, final List<IncomingMessage> deviceMessages,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
validateDeviceMessages(account, deviceMessages); validateDeviceMessages(account, deviceMessages);

View File

@ -13,7 +13,7 @@ import java.util.stream.Collectors;
import java.util.stream.LongStream; import java.util.stream.LongStream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
public class Device { public class Device {
@ -61,10 +61,10 @@ public class Device {
private Integer phoneNumberIdentityRegistrationId; private Integer phoneNumberIdentityRegistrationId;
@JsonProperty @JsonProperty
private SignedPreKey signedPreKey; private ECSignedPreKey signedPreKey;
@JsonProperty("pniSignedPreKey") @JsonProperty("pniSignedPreKey")
private SignedPreKey phoneNumberIdentitySignedPreKey; private ECSignedPreKey phoneNumberIdentitySignedPreKey;
@JsonProperty @JsonProperty
private long lastSeen; private long lastSeen;
@ -230,19 +230,19 @@ public class Device {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
} }
public SignedPreKey getSignedPreKey() { public ECSignedPreKey getSignedPreKey() {
return signedPreKey; return signedPreKey;
} }
public void setSignedPreKey(SignedPreKey signedPreKey) { public void setSignedPreKey(ECSignedPreKey signedPreKey) {
this.signedPreKey = signedPreKey; this.signedPreKey = signedPreKey;
} }
public SignedPreKey getPhoneNumberIdentitySignedPreKey() { public ECSignedPreKey getPhoneNumberIdentitySignedPreKey() {
return phoneNumberIdentitySignedPreKey; return phoneNumberIdentitySignedPreKey;
} }
public void setPhoneNumberIdentitySignedPreKey(final SignedPreKey phoneNumberIdentitySignedPreKey) { public void setPhoneNumberIdentitySignedPreKey(final ECSignedPreKey phoneNumberIdentitySignedPreKey) {
this.phoneNumberIdentitySignedPreKey = phoneNumberIdentitySignedPreKey; this.phoneNumberIdentitySignedPreKey = phoneNumberIdentitySignedPreKey;
} }

View File

@ -13,15 +13,15 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
public class KeysManager { public class KeysManager {
private final SingleUseECPreKeyStore ecPreKeys; private final SingleUseECPreKeyStore ecPreKeys;
private final SingleUseKEMPreKeyStore pqPreKeys; private final SingleUseKEMPreKeyStore pqPreKeys;
private final RepeatedUseSignedPreKeyStore pqLastResortKeys; private final RepeatedUseKEMSignedPreKeyStore pqLastResortKeys;
public KeysManager( public KeysManager(
final DynamoDbAsyncClient dynamoDbAsyncClient, final DynamoDbAsyncClient dynamoDbAsyncClient,
@ -30,18 +30,18 @@ public class KeysManager {
final String pqLastResortTableName) { final String pqLastResortTableName) {
this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName); this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName);
this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName); this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName);
this.pqLastResortKeys = new RepeatedUseSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName); this.pqLastResortKeys = new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName);
} }
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) { public void store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) {
store(identifier, deviceId, keys, null, null); store(identifier, deviceId, keys, null, null);
} }
public void store( public void store(
final UUID identifier, final long deviceId, final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys, @Nullable final List<ECPreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys, @Nullable final List<KEMSignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) { @Nullable final KEMSignedPreKey pqLastResortKey) {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(); final List<CompletableFuture<Void>> storeFutures = new ArrayList<>();
@ -60,15 +60,15 @@ public class KeysManager {
CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join(); CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join();
} }
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) { public void storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) {
pqLastResortKeys.store(identifier, keys).join(); pqLastResortKeys.store(identifier, keys).join();
} }
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) { public Optional<ECPreKey> takeEC(final UUID identifier, final long deviceId) {
return ecPreKeys.take(identifier, deviceId).join(); return ecPreKeys.take(identifier, deviceId).join();
} }
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) { public Optional<KEMSignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return pqPreKeys.take(identifier, deviceId) return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
@ -76,9 +76,8 @@ public class KeysManager {
} }
@VisibleForTesting @VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) { Optional<KEMSignedPreKey> getLastResort(final UUID identifier, final long deviceId) {
return pqLastResortKeys.find(identifier, deviceId).join() return pqLastResortKeys.find(identifier, deviceId).join();
.map(signedPreKey -> signedPreKey);
} }
public List<Long> getPqEnabledDevices(final UUID identifier) { public List<Long> getPqEnabledDevices(final UUID identifier) {

View File

@ -0,0 +1,46 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map;
import java.util.UUID;
public class RepeatedUseKEMSignedPreKeyStore extends RepeatedUseSignedPreKeyStore<KEMSignedPreKey> {
public RepeatedUseKEMSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName);
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final KEMSignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID, getSortKey(deviceId),
ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.keyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature()));
}
@Override
protected KEMSignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
try {
return new KEMSignedPreKey(
Long.parseLong(item.get(ATTR_KEY_ID).n()),
new KEMPublicKey(item.get(ATTR_PUBLIC_KEY).b().asByteArray()),
item.get(ATTR_SIGNATURE).b().asByteArray());
} catch (final InvalidKeyException e) {
// This should never happen since we're serializing keys directly from `KEMPublicKey` instances on the way in
throw new IllegalArgumentException(e);
}
}
}

View File

@ -5,15 +5,12 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
@ -37,7 +34,7 @@ import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
* Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed * Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed
* provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key. * provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key.
*/ */
public class RepeatedUseSignedPreKeyStore { public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
private final DynamoDbAsyncClient dynamoDbAsyncClient; private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName; private final String tableName;
@ -63,9 +60,6 @@ public class RepeatedUseSignedPreKeyStore {
private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey"); private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
private static final Counter INVALID_KEY_COUNTER =
Metrics.counter(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "invalidKey"));
public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
this.dynamoDbAsyncClient = dynamoDbAsyncClient; this.dynamoDbAsyncClient = dynamoDbAsyncClient;
this.tableName = tableName; this.tableName = tableName;
@ -81,7 +75,7 @@ public class RepeatedUseSignedPreKeyStore {
* *
* @return a future that completes once the key has been stored * @return a future that completes once the key has been stored
*/ */
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final K signedPreKey) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder() return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
@ -101,14 +95,14 @@ public class RepeatedUseSignedPreKeyStore {
* *
* @return a future that completes once all keys have been stored * @return a future that completes once all keys have been stored
*/ */
public CompletableFuture<Void> store(final UUID identifier, final Map<Long, SignedPreKey> signedPreKeysByDeviceId) { public CompletableFuture<Void> store(final UUID identifier, final Map<Long, K> signedPreKeysByDeviceId) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder() return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(signedPreKeysByDeviceId.entrySet().stream() .transactItems(signedPreKeysByDeviceId.entrySet().stream()
.map(entry -> { .map(entry -> {
final long deviceId = entry.getKey(); final long deviceId = entry.getKey();
final SignedPreKey signedPreKey = entry.getValue(); final K signedPreKey = entry.getValue();
return TransactWriteItem.builder() return TransactWriteItem.builder()
.put(Put.builder() .put(Put.builder()
@ -131,10 +125,10 @@ public class RepeatedUseSignedPreKeyStore {
* @return a future that yields an optional signed pre-key if one is available for the target device or empty if no * @return a future that yields an optional signed pre-key if one is available for the target device or empty if no
* key could be found for the target device * key could be found for the target device
*/ */
public CompletableFuture<Optional<SignedPreKey>> find(final UUID identifier, final long deviceId) { public CompletableFuture<Optional<K>> find(final UUID identifier, final long deviceId) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final CompletableFuture<Optional<SignedPreKey>> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder() final CompletableFuture<Optional<K>> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder()
.tableName(tableName) .tableName(tableName)
.key(getPrimaryKey(identifier, deviceId)) .key(getPrimaryKey(identifier, deviceId))
.consistentRead(true) .consistentRead(true)
@ -202,41 +196,21 @@ public class RepeatedUseSignedPreKeyStore {
.map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n())); .map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n()));
} }
private static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final long deviceId) { protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final long deviceId) {
return Map.of( return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID, getSortKey(deviceId)); KEY_DEVICE_ID, getSortKey(deviceId));
} }
private static AttributeValue getPartitionKey(final UUID accountUuid) { protected static AttributeValue getPartitionKey(final UUID accountUuid) {
return AttributeValues.fromUUID(accountUuid); return AttributeValues.fromUUID(accountUuid);
} }
private static AttributeValue getSortKey(final long deviceId) { protected static AttributeValue getSortKey(final long deviceId) {
return AttributeValues.fromLong(deviceId); return AttributeValues.fromLong(deviceId);
} }
private static Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final SignedPreKey signedPreKey) { protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final K signedPreKey);
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID, getSortKey(deviceId),
ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.getKeyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature()));
}
private static SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) { protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
final byte[] publicKeyBytes = item.get(ATTR_PUBLIC_KEY).b().asByteArray();
try {
new KEMPublicKey(publicKeyBytes);
} catch (final InvalidKeyException e) {
INVALID_KEY_COUNTER.increment();
}
return new SignedPreKey(
Long.parseLong(item.get(ATTR_KEY_ID).n()),
publicKeyBytes,
item.get(ATTR_SIGNATURE).b().asByteArray());
}
} }

View File

@ -5,46 +5,39 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<PreKey> { public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
private static final Counter INVALID_KEY_COUNTER =
Metrics.counter(MetricsUtil.name(SingleUseECPreKeyStore.class, "invalidKey"));
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName); super(dynamoDbAsyncClient, tableName);
} }
@Override @Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final PreKey preKey) { protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final ECPreKey preKey) {
return Map.of( return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey())); ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.serializedPublicKey()));
} }
@Override @Override
protected PreKey getPreKeyFromItem(final Map<String, AttributeValue> item) { protected ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY));
try { try {
new ECPublicKey(publicKey); return new ECPreKey(keyId, new ECPublicKey(publicKey));
} catch (final InvalidKeyException e) { } catch (final InvalidKeyException e) {
INVALID_KEY_COUNTER.increment(); // This should never happen since we're serializing keys directly from `ECPublicKey` instances on the way in
throw new IllegalArgumentException(e);
} }
return new PreKey(keyId, publicKey);
} }
} }

View File

@ -5,48 +5,41 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey; import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<SignedPreKey> { public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKey> {
private static final Counter INVALID_KEY_COUNTER =
Metrics.counter(MetricsUtil.name(SingleUseKEMPreKeyStore.class, "invalidKey"));
protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName); super(dynamoDbAsyncClient, tableName);
} }
@Override @Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final KEMSignedPreKey signedPreKey) {
return Map.of( return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.getKeyId()), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()),
ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()), ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.serializedPublicKey()),
ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature())); ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.signature()));
} }
@Override @Override
protected SignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) { protected KEMSignedPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); final byte[] publicKey = item.get(ATTR_PUBLIC_KEY).b().asByteArray();
final byte[] signature = extractByteArray(item.get(ATTR_SIGNATURE)); final byte[] signature = item.get(ATTR_SIGNATURE).b().asByteArray();
try { try {
new KEMPublicKey(publicKey); return new KEMSignedPreKey(keyId, new KEMPublicKey(publicKey), signature);
} catch (final InvalidKeyException e) { } catch (final InvalidKeyException e) {
INVALID_KEY_COUNTER.increment(); // This should never happen since we're serializing keys directly from `KEMPublicKey` instances on the way in
throw new IllegalArgumentException(e);
} }
return new SignedPreKey(keyId, publicKey, signature);
} }
} }

View File

@ -47,7 +47,7 @@ import software.amazon.awssdk.services.dynamodb.model.Select;
* the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party * the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead. * may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
*/ */
public abstract class SingleUsePreKeyStore<K extends PreKey> { public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
private final DynamoDbAsyncClient dynamoDbAsyncClient; private final DynamoDbAsyncClient dynamoDbAsyncClient;
private final String tableName; private final String tableName;

View File

@ -0,0 +1,52 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
public class ECPublicKeyAdapter {
public static class Serializer extends JsonSerializer<ECPublicKey> {
@Override
public void serialize(final ECPublicKey ecPublicKey,
final JsonGenerator jsonGenerator,
final SerializerProvider serializers) throws IOException {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(ecPublicKey.serialize()));
}
}
public static class Deserializer extends JsonDeserializer<ECPublicKey> {
@Override
public ECPublicKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException {
final byte[] ecPublicKeyBytes;
try {
ecPublicKeyBytes = Base64.getDecoder().decode(parser.getValueAsString());
} catch (final IllegalArgumentException e) {
throw new JsonParseException(parser, "Could not parse EC public key as a base64-encoded value", e);
}
try {
return new ECPublicKey(ecPublicKeyBytes);
} catch (final InvalidKeyException e) {
throw new JsonParseException(parser, "Could not interpret key bytes as an EC public key", e);
}
}
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import java.io.IOException;
import java.util.Base64;
public class KEMPublicKeyAdapter {
public static class Serializer extends JsonSerializer<KEMPublicKey> {
@Override
public void serialize(final KEMPublicKey kemPublicKey,
final JsonGenerator jsonGenerator,
final SerializerProvider serializers) throws IOException {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(kemPublicKey.serialize()));
}
}
public static class Deserializer extends JsonDeserializer<KEMPublicKey> {
@Override
public KEMPublicKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException {
final byte[] kemPublicKeyBytes;
try {
kemPublicKeyBytes = Base64.getDecoder().decode(parser.getValueAsString());
} catch (final IllegalArgumentException e) {
throw new JsonParseException(parser, "Could not parse KEM public key as a base64-encoded value", e);
}
try {
return new KEMPublicKey(kemPublicKeyBytes);
} catch (final InvalidKeyException e) {
throw new JsonParseException(parser, "Could not interpret key bytes as a KEM public key", e);
}
}
}
}

View File

@ -1,37 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.ElementType.TYPE_USE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import java.lang.annotation.Documented;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import javax.validation.Constraint;
import javax.validation.Payload;
@Target({FIELD, PARAMETER, TYPE_USE})
@Retention(RUNTIME)
@Constraint(validatedBy = {ValidPreKeyValidator.class})
@Documented
public @interface ValidPreKey {
public enum PreKeyType {
ECC,
KYBER
}
PreKeyType type();
String message() default "{org.whispersystems.textsecuregcm.util.ValidPreKey.message}";
Class<?>[] groups() default { };
Class<? extends Payload>[] payload() default { };
}

View File

@ -1,38 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.PreKey;
public class ValidPreKeyValidator implements ConstraintValidator<ValidPreKey, PreKey> {
private ValidPreKey.PreKeyType type;
@Override
public void initialize(ValidPreKey annotation) {
type = annotation.type();
}
@Override
public boolean isValid(PreKey value, ConstraintValidatorContext context) {
if (value == null) {
return true;
}
try {
switch (type) {
case ECC -> Curve.decodePoint(value.getPublicKey(), 0);
case KYBER -> new KEMPublicKey(value.getPublicKey());
}
} catch (IllegalArgumentException | InvalidKeyException e) {
return false;
}
return true;
}
}

View File

@ -74,6 +74,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
@ -84,7 +85,6 @@ import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
@ -194,7 +194,7 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
} }
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device(); final Device device = new Device();
device.setId(id); device.setId(id);
device.setRegistrationId(registrationId); device.setRegistrationId(registrationId);

View File

@ -55,10 +55,11 @@ import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
@ -418,10 +419,10 @@ class RegistrationControllerTest {
static Stream<Arguments> atomicAccountCreationConflictingChannel() { static Stream<Arguments> atomicAccountCreationConflictingChannel() {
final Optional<IdentityKey> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey; final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -507,10 +508,10 @@ class RegistrationControllerTest {
static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() { static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
final Optional<IdentityKey> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey; final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -620,10 +621,10 @@ class RegistrationControllerTest {
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final IdentityKey expectedAciIdentityKey, final IdentityKey expectedAciIdentityKey,
final IdentityKey expectedPniIdentityKey, final IdentityKey expectedPniIdentityKey,
final SignedPreKey expectedAciSignedPreKey, final ECSignedPreKey expectedAciSignedPreKey,
final SignedPreKey expectedPniSignedPreKey, final ECSignedPreKey expectedPniSignedPreKey,
final SignedPreKey expectedAciPqLastResortPreKey, final KEMSignedPreKey expectedAciPqLastResortPreKey,
final SignedPreKey expectedPniPqLastResortPreKey, final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<String> expectedApnsToken, final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken, final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) throws InterruptedException { final Optional<String> expectedGcmToken) throws InterruptedException {
@ -686,10 +687,10 @@ class RegistrationControllerTest {
private static Stream<Arguments> atomicAccountCreationSuccess() { private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey; final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
{ {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();

View File

@ -29,7 +29,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -149,17 +149,17 @@ class AccountsManagerChangeNumberIntegrationTest {
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
final int rotatedPniRegistrationId = 17; final int rotatedPniRegistrationId = 17;
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final SignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair); final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair);
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities()); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities());
final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>()); final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>());
account.getMasterDevice().orElseThrow().setSignedPreKey(new SignedPreKey()); account.getMasterDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map<Long, ECSignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);

View File

@ -55,7 +55,8 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@ -673,9 +674,10 @@ class AccountsManagerTest {
final String number = "+14152222222"; final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber( () -> accountsManager.changeNumber(
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any()); verify(accounts, never()).update(any());
@ -719,10 +721,10 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> newSignedKeys = Map.of( final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair), 1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair)); 2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of( final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202); final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
@ -768,14 +770,14 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Long, SignedPreKey> newSignedKeys = Map.of( Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair), 1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair)); 2L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202); Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
UUID oldUuid = account.getUuid(); UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier(); UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -810,10 +812,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> newSignedKeys = Map.of( final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair), 1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair)); 2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of( final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202); Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
@ -823,7 +825,7 @@ class AccountsManagerTest {
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());

View File

@ -1023,9 +1023,7 @@ class AccountsTest {
assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId()); assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId());
assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId()); assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId());
assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen()); assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen());
assertThat(resultDevice.getSignedPreKey().getPublicKey()).isEqualTo(expectingDevice.getSignedPreKey().getPublicKey()); assertThat(resultDevice.getSignedPreKey()).isEqualTo(expectingDevice.getSignedPreKey());
assertThat(resultDevice.getSignedPreKey().getKeyId()).isEqualTo(expectingDevice.getSignedPreKey().getKeyId());
assertThat(resultDevice.getSignedPreKey().getSignature()).isEqualTo(expectingDevice.getSignedPreKey().getSignature());
assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages()); assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages());
assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent()); assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent());
assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName()); assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName());

View File

@ -28,11 +28,15 @@ import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -106,8 +110,9 @@ public class ChangeNumberManagerTest {
void changeNumberSetPrimaryDevicePrekey() throws Exception { void changeNumberSetPrimaryDevicePrekey() throws Exception {
Account account = mock(Account.class); Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
var prekeys = Map.of(1L, new SignedPreKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
@ -133,8 +138,9 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
@ -176,9 +182,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
@ -218,9 +225,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
@ -258,8 +266,9 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
@ -297,9 +306,10 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
@ -344,7 +354,10 @@ public class ChangeNumberManagerTest {
new IncomingMessage(1, 2, 1, "foo"), new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo")); new IncomingMessage(1, 3, 1, "foo"));
final Map<Long, SignedPreKey> preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,
@ -374,7 +387,10 @@ public class ChangeNumberManagerTest {
new IncomingMessage(1, 2, 1, "foo"), new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo")); new IncomingMessage(1, 3, 1, "foo"));
final Map<Long, SignedPreKey> preKeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey(), 3L, new SignedPreKey()); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,

View File

@ -13,14 +13,14 @@ import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
class DeviceTest { class DeviceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId, void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId,
final SignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) { final ECSignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) {
final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis(); final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis();
@ -41,36 +41,36 @@ class DeviceTest {
// master fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled // master fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled
Arguments.of(true, false, null, null, null, Duration.ofDays(60), false), Arguments.of(true, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, null, null, null, Duration.ofDays(1), false), Arguments.of(true, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(true, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false), Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false), Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false), Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), true), Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false), Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false), Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), true), Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, true, null, null, null, Duration.ofDays(60), false), Arguments.of(true, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, true, null, null, null, Duration.ofDays(1), false), Arguments.of(true, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), true), Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, null, null, null, Duration.ofDays(60), false), Arguments.of(false, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, false, null, null, null, Duration.ofDays(1), false), Arguments.of(false, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, null, mock(SignedPreKey.class), Duration.ofDays(1), false), Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false), Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false), Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(60), false), Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", mock(SignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false), Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false), Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(60), false), Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, mock(SignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, true, null, null, null, Duration.ofDays(60), false), Arguments.of(false, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, true, null, null, null, Duration.ofDays(1), false), Arguments.of(false, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(60), false), Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, true, null, null, mock(SignedPreKey.class), Duration.ofDays(1), true) Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true)
); );
} }
} }

View File

@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import java.security.SecureRandom;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -21,8 +20,9 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -37,6 +37,8 @@ class KeysManagerTest {
private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L; private static final long DEVICE_ID = 1L;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@BeforeEach @BeforeEach
void setup() { void setup() {
keysManager = new KeysManager( keysManager = new KeysManager(
@ -62,17 +64,17 @@ class KeysManagerTest {
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect"); "Repeatedly storing same key should have no effect");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys"); "Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001)); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestKEMSignedPreKey(1001));
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys"); "Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId());
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
@ -80,7 +82,7 @@ class KeysManagerTest {
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys"); "Uploading new EC prekeys should have no effect on PQ prekeys");
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device"); "Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
@ -88,13 +90,12 @@ class KeysManagerTest {
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(4), generateTestPreKey(5)), List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)), List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(1002));
generateTestSignedPreKey(1002));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device"); "Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device"); "Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(), assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().keyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device"); "Uploading new last-resort key should overwrite prior last-resort key for the account/device");
} }
@ -102,10 +103,10 @@ class KeysManagerTest {
void testTakeAccountAndDeviceId() { void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = generateTestPreKey(1); final ECPreKey preKey = generateTestPreKey(1);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); final Optional<ECPreKey> takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey); assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
} }
@ -114,9 +115,9 @@ class KeysManagerTest {
void testTakePQ() { void testTakePQ() {
assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = generateTestSignedPreKey(1); final KEMSignedPreKey preKey1 = generateTestKEMSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2); final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001); final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001);
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
@ -138,7 +139,7 @@ class KeysManagerTest {
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null); keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null);
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
} }
@ -147,13 +148,11 @@ class KeysManagerTest {
void testDeleteByAccount() { void testDeleteByAccount() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)), List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5));
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)), List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)), List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8));
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@ -176,13 +175,11 @@ class KeysManagerTest {
void testDeleteByAccountAndDevice() { void testDeleteByAccountAndDevice() {
keysManager.store(ACCOUNT_UUID, DEVICE_ID, keysManager.store(ACCOUNT_UUID, DEVICE_ID,
List.of(generateTestPreKey(1), generateTestPreKey(2)), List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), generateTestKEMSignedPreKey(5));
generateTestSignedPreKey(5));
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(generateTestPreKey(6)), List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)), List.of(generateTestKEMSignedPreKey(7)), generateTestKEMSignedPreKey(8));
generateTestSignedPreKey(8));
assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@ -211,17 +208,17 @@ class KeysManagerTest {
ACCOUNT_UUID, ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair)));
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId()); assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId()); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keysManager.storePqLastResort( keysManager.storePqLastResort(
ACCOUNT_UUID, ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair)));
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates"); assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone"); assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().keyId(), "storing new last-resort keys should overwrite old ones");
} }
@Test @Test
@ -237,21 +234,15 @@ class KeysManagerTest {
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID))); Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID)));
} }
private static PreKey generateTestPreKey(final long keyId) { private static ECPreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32]; return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
new SecureRandom().nextBytes(key);
return new PreKey(keyId, key);
} }
private static SignedPreKey generateTestSignedPreKey(final long keyId) { private static ECSignedPreKey generateTestECSignedPreKey(final long keyId) {
final byte[] key = new byte[32]; return KeysHelper.signedECPreKey(keyId, IDENTITY_KEY_PAIR);
final byte[] signature = new byte[32]; }
final SecureRandom secureRandom = new SecureRandom(); private static KEMSignedPreKey generateTestKEMSignedPreKey(final long keyId) {
secureRandom.nextBytes(key); return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, key, signature);
} }
} }

View File

@ -0,0 +1,44 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import static org.junit.jupiter.api.Assertions.*;
class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest<KEMSignedPreKey> {
private RepeatedUseKEMSignedPreKeyStore keyStore;
private int currentKeyId = 1;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@BeforeEach
void setUp() {
keyStore = new RepeatedUseKEMSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Override
protected RepeatedUseSignedPreKeyStore<KEMSignedPreKey> getKeyStore() {
return keyStore;
}
@Override
protected KEMSignedPreKey generateSignedPreKey() {
return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR);
}
}

View File

@ -5,59 +5,31 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.reactivestreams.Subscriber;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.Test;
import java.util.concurrent.CompletionException; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import static org.junit.jupiter.api.Assertions.*; abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class RepeatedUseSignedPreKeyStoreTest { protected abstract RepeatedUseSignedPreKeyStore<K> getKeyStore();
private RepeatedUseSignedPreKeyStore keys; protected abstract K generateSignedPreKey();
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION =
new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS);
@BeforeEach
void setUp() {
keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
}
@Test @Test
void storeFind() { void storeFind() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join()); assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join());
{ {
final UUID identifier = UUID.randomUUID(); final UUID identifier = UUID.randomUUID();
final long deviceId = 1; final long deviceId = 1;
final SignedPreKey signedPreKey = generateSignedPreKey(); final K signedPreKey = generateSignedPreKey();
assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join());
assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join()); assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join());
@ -65,7 +37,7 @@ class RepeatedUseSignedPreKeyStoreTest {
{ {
final UUID identifier = UUID.randomUUID(); final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of( final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(), 1L, generateSignedPreKey(),
2L, generateSignedPreKey() 2L, generateSignedPreKey()
); );
@ -78,11 +50,13 @@ class RepeatedUseSignedPreKeyStoreTest {
@Test @Test
void delete() { void delete() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
{ {
final UUID identifier = UUID.randomUUID(); final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of( final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(), 1L, generateSignedPreKey(),
2L, generateSignedPreKey() 2L, generateSignedPreKey()
); );
@ -96,7 +70,7 @@ class RepeatedUseSignedPreKeyStoreTest {
{ {
final UUID identifier = UUID.randomUUID(); final UUID identifier = UUID.randomUUID();
final Map<Long, SignedPreKey> signedPreKeys = Map.of( final Map<Long, K> signedPreKeys = Map.of(
1L, generateSignedPreKey(), 1L, generateSignedPreKey(),
2L, generateSignedPreKey() 2L, generateSignedPreKey()
); );
@ -108,42 +82,4 @@ class RepeatedUseSignedPreKeyStoreTest {
assertEquals(Optional.empty(), keys.find(identifier, 2).join()); assertEquals(Optional.empty(), keys.find(identifier, 2).join());
} }
} }
@Test
void deleteWithError() {
final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class);
final QueryPublisher queryPublisher = mock(QueryPublisher.class);
final SdkPublisher<Map<String, AttributeValue>> itemPublisher = new SdkPublisher<Map<String, AttributeValue>>() {
final Flux<Map<String, AttributeValue>> items = Flux.just(
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)),
Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2)));
@Override
public void subscribe(final Subscriber<? super Map<String, AttributeValue>> subscriber) {
items.subscribe(subscriber);
}
};
when(queryPublisher.items()).thenReturn(itemPublisher);
when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher);
final Exception deleteItemException = new IllegalArgumentException("OH NO");
when(mockClient.deleteItem(any(DeleteItemRequest.class)))
.thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build()))
.thenReturn(CompletableFuture.failedFuture(deleteItemException));
final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient,
DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName());
final CompletionException completionException =
assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join());
assertEquals(deleteItemException, completionException.getCause());
}
private static SignedPreKey generateSignedPreKey() {
return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR);
}
} }

View File

@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> { class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<ECPreKey> {
private SingleUseECPreKeyStore preKeyStore; private SingleUseECPreKeyStore preKeyStore;
@ -24,12 +24,12 @@ class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest<PreKey> {
} }
@Override @Override
protected SingleUsePreKeyStore<PreKey> getPreKeyStore() { protected SingleUsePreKeyStore<ECPreKey> getPreKeyStore() {
return preKeyStore; return preKeyStore;
} }
@Override @Override
protected PreKey generatePreKey(final long keyId) { protected ECPreKey generatePreKey(final long keyId) {
return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize()); return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey());
} }
} }

View File

@ -9,10 +9,10 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey> { class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<KEMSignedPreKey> {
private SingleUseKEMPreKeyStore preKeyStore; private SingleUseKEMPreKeyStore preKeyStore;
@ -28,12 +28,12 @@ class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest<SignedPreKey>
} }
@Override @Override
protected SingleUsePreKeyStore<SignedPreKey> getPreKeyStore() { protected SingleUsePreKeyStore<KEMSignedPreKey> getPreKeyStore() {
return preKeyStore; return preKeyStore;
} }
@Override @Override
protected SignedPreKey generatePreKey(final long keyId) { protected KEMSignedPreKey generatePreKey(final long keyId) {
return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR); return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR);
} }
} }

View File

@ -24,7 +24,7 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
abstract class SingleUsePreKeyStoreTest<K extends PreKey> { abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
private static final int KEY_COUNT = 100; private static final int KEY_COUNT = 100;

View File

@ -54,9 +54,10 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceResponse; import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
@ -265,10 +266,10 @@ class DeviceControllerTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey; final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey; final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -351,10 +352,10 @@ class DeviceControllerTest {
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
final Optional<SignedPreKey> aciSignedPreKey; final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<SignedPreKey> pniPqLastResortPreKey; final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -395,10 +396,10 @@ class DeviceControllerTest {
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey, void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final Optional<SignedPreKey> aciSignedPreKey, final Optional<ECSignedPreKey> aciSignedPreKey,
final Optional<SignedPreKey> pniSignedPreKey, final Optional<ECSignedPreKey> pniSignedPreKey,
final Optional<SignedPreKey> aciPqLastResortPreKey, final Optional<KEMSignedPreKey> aciPqLastResortPreKey,
final Optional<SignedPreKey> pniPqLastResortPreKey) { final Optional<KEMSignedPreKey> pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@ -435,10 +436,10 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final Optional<SignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); final Optional<ECSignedPreKey> aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
final Optional<SignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); final Optional<ECSignedPreKey> pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Optional<SignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); final Optional<KEMSignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<SignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final Optional<KEMSignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
@ -455,10 +456,10 @@ class DeviceControllerTest {
@MethodSource @MethodSource
void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey, void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final SignedPreKey aciSignedPreKey, final ECSignedPreKey aciSignedPreKey,
final SignedPreKey pniSignedPreKey, final ECSignedPreKey pniSignedPreKey,
final SignedPreKey aciPqLastResortPreKey, final KEMSignedPreKey aciPqLastResortPreKey,
final SignedPreKey pniPqLastResortPreKey) { final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@ -495,25 +496,31 @@ class DeviceControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final SignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair); final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
final SignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair); final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of( return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, ecSignedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, signedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, ecSignedPreKeyWithBadSignature(pniSignedPreKey), aciPqLastResortPreKey, pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, signedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, kemSignedPreKeyWithBadSignature(aciPqLastResortPreKey), pniPqLastResortPreKey),
Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, signedPreKeyWithBadSignature(pniPqLastResortPreKey)) Arguments.of(aciIdentityKey, pniIdentityKey, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, kemSignedPreKeyWithBadSignature(pniPqLastResortPreKey))
); );
} }
private static SignedPreKey signedPreKeyWithBadSignature(final SignedPreKey signedPreKey) { private static ECSignedPreKey ecSignedPreKeyWithBadSignature(final ECSignedPreKey signedPreKey) {
return new SignedPreKey(signedPreKey.getKeyId(), return new ECSignedPreKey(signedPreKey.keyId(),
signedPreKey.getPublicKey(), signedPreKey.publicKey(),
"incorrect-signature".getBytes(StandardCharsets.UTF_8));
}
private static KEMSignedPreKey kemSignedPreKeyWithBadSignature(final KEMSignedPreKey signedPreKey) {
return new KEMSignedPreKey(signedPreKey.keyId(),
signedPreKey.publicKey(),
"incorrect-signature".getBytes(StandardCharsets.UTF_8)); "incorrect-signature".getBytes(StandardCharsets.UTF_8));
} }

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.tests.controllers; package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
@ -20,6 +21,8 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -47,6 +50,9 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccou
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
@ -63,6 +69,7 @@ import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest { class KeysControllerTest {
@ -86,27 +93,27 @@ class KeysControllerTest {
private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()); private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey());
private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); private final ECPreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234);
private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); private final ECPreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667);
private final PreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334); private final ECPreKey SAMPLE_KEY3 = KeysHelper.ecPreKey(334);
private final PreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336); private final ECPreKey SAMPLE_KEY4 = KeysHelper.ecPreKey(336);
private final PreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777); private final ECPreKey SAMPLE_KEY_PNI = KeysHelper.ecPreKey(7777);
private final SignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair()); private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair());
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedECPreKey(1111, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedECPreKey(2222, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedECPreKey(3333, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedECPreKey(4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedECPreKey(5555, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR); private final ECSignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedECPreKey(6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR); private final ECSignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR); private final ECSignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR);
private final static KeysManager KEYS = mock(KeysManager.class ); private final static KeysManager KEYS = mock(KeysManager.class );
private final static AccountsManager accounts = mock(AccountsManager.class ); private final static AccountsManager accounts = mock(AccountsManager.class );
@ -127,6 +134,42 @@ class KeysControllerTest {
private Device sampleDevice; private Device sampleDevice;
private record WeaklyTypedPreKey(long keyId,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] publicKey) {
static WeaklyTypedPreKey fromPreKey(final PreKey<?> preKey) {
return new WeaklyTypedPreKey(preKey.keyId(), preKey.serializedPublicKey());
}
}
private record WeaklyTypedSignedPreKey(long keyId,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] publicKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] signature) {
static WeaklyTypedSignedPreKey fromSignedPreKey(final SignedPreKey<?> signedPreKey) {
return new WeaklyTypedSignedPreKey(signedPreKey.keyId(), signedPreKey.serializedPublicKey(), signedPreKey.signature());
}
}
private record WeaklyTypedPreKeyState(List<WeaklyTypedPreKey> preKeys,
WeaklyTypedSignedPreKey signedPreKey,
List<WeaklyTypedSignedPreKey> pqPreKeys,
WeaklyTypedSignedPreKey pqLastResortPreKey,
@JsonSerialize(using = ByteArrayAdapter.Serializing.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class)
byte[] identityKey) {
}
@BeforeEach @BeforeEach
void setup() { void setup() {
sampleDevice = mock(Device.class); sampleDevice = mock(Device.class);
@ -228,30 +271,30 @@ class KeysControllerTest {
@Test @Test
void getSignedPreKeyV2() { void getSignedPreKeyV2() {
SignedPreKey result = resources.getJerseyTest() ECSignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class); .get(ECSignedPreKey.class);
assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result); assertEquals(VALID_DEVICE_SIGNED_KEY, result);
} }
@Test @Test
void getPhoneNumberIdentifierSignedPreKeyV2() { void getPhoneNumberIdentifierSignedPreKeyV2() {
SignedPreKey result = resources.getJerseyTest() ECSignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.queryParam("identity", "pni") .queryParam("identity", "pni")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class); .get(ECSignedPreKey.class);
assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result); assertEquals(VALID_DEVICE_PNI_SIGNED_KEY, result);
} }
@Test @Test
void putSignedPreKeyV2() { void putSignedPreKeyV2() {
SignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR); ECSignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.request() .request()
@ -267,7 +310,7 @@ class KeysControllerTest {
@Test @Test
void putPhoneNumberIdentitySignedPreKeyV2() { void putPhoneNumberIdentitySignedPreKeyV2() {
final SignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR); final ECSignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
@ -285,7 +328,7 @@ class KeysControllerTest {
@Test @Test
void disabledPutSignedPreKeyV2() { void disabledPutSignedPreKeyV2() {
SignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR); ECSignedPreKey test = KeysHelper.signedECPreKey(9999, IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.request() .request()
@ -305,10 +348,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
@ -316,7 +359,7 @@ class KeysControllerTest {
@Test @Test
void validSingleRequestPqTestNoPqKeysV2() { void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.<SignedPreKey>empty()); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty());
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@ -327,10 +370,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1);
@ -348,10 +391,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1);
@ -368,10 +411,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1); verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
@ -388,10 +431,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1); verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1); verify(KEYS).takePQ(EXISTS_PNI, 1);
@ -410,10 +453,10 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1); verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
@ -445,9 +488,9 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1); verify(KEYS).takePQ(EXISTS_UUID, 1);
@ -510,14 +553,14 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey(); ECPreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId(); long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId(); long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey); assertEquals(SAMPLE_KEY, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1); assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey(); signedPreKey = results.getDevice(2).getSignedPreKey();
@ -525,9 +568,9 @@ class KeysControllerTest {
registrationId = results.getDevice(2).getRegistrationId(); registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId(); deviceId = results.getDevice(2).getDeviceId();
assertKeysMatch(SAMPLE_KEY2, preKey); assertEquals(SAMPLE_KEY2, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2); assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey(); signedPreKey = results.getDevice(4).getSignedPreKey();
@ -535,7 +578,7 @@ class KeysControllerTest {
registrationId = results.getDevice(4).getRegistrationId(); registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId(); deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey); assertEquals(SAMPLE_KEY4, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull(); assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4); assertThat(deviceId).isEqualTo(4);
@ -554,7 +597,7 @@ class KeysControllerTest {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2)); when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3)); when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.<SignedPreKey>empty()); when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty());
PreKeyResponse results = resources.getJerseyTest() PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@ -566,16 +609,16 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey(); ECPreKey preKey = results.getDevice(1).getPreKey();
SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey(); KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId(); long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId(); long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey); assertEquals(SAMPLE_KEY, preKey);
assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey); assertEquals(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1); assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey(); signedPreKey = results.getDevice(2).getSignedPreKey();
@ -585,9 +628,9 @@ class KeysControllerTest {
deviceId = results.getDevice(2).getDeviceId(); deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey).isNull(); assertThat(preKey).isNull();
assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey); assertEquals(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2); assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey(); signedPreKey = results.getDevice(4).getSignedPreKey();
@ -596,7 +639,7 @@ class KeysControllerTest {
registrationId = results.getDevice(4).getRegistrationId(); registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId(); deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey); assertEquals(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull(); assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull(); assertThat(signedPreKey).isNull();
@ -656,9 +699,9 @@ class KeysControllerTest {
@Test @Test
void putKeysTestV2() { void putKeysTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@ -672,7 +715,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey); assertThat(listCaptor.getValue()).containsExactly(preKey);
@ -684,11 +727,11 @@ class KeysControllerTest {
@Test @Test
void putKeysPqTestV2() { void putKeysPqTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@ -702,8 +745,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(ecCaptor.getValue()).containsExactly(preKey);
@ -718,8 +761,9 @@ class KeysControllerTest {
void putKeysStructurallyInvalidSignedECKey() { void putKeysStructurallyInvalidSignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair); final KEMSignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null); final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, WeaklyTypedSignedPreKey.fromSignedPreKey(wrongPreKey), null, null, identityKey.serialize());
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -728,15 +772,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test @Test
void putKeysStructurallyInvalidUnsignedECKey() { void putKeysStructurallyInvalidUnsignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes()); final WeaklyTypedPreKey wrongPreKey = new WeaklyTypedPreKey(1, "cluck cluck i'm a parrot".getBytes());
final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null); final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(List.of(wrongPreKey), null, null, null, identityKey.serialize());
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -745,15 +790,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test @Test
void putKeysStructurallyInvalidPQOneTimeKey() { void putKeysStructurallyInvalidPQOneTimeKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair));
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null); final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, null, List.of(wrongPreKey), null, identityKey.serialize());
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -762,15 +808,16 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test @Test
void putKeysStructurallyInvalidPQLastResortKey() { void putKeysStructurallyInvalidPQLastResortKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final WeaklyTypedSignedPreKey wrongPreKey = WeaklyTypedSignedPreKey.fromSignedPreKey(KeysHelper.signedECPreKey(1, identityKeyPair));
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey); final WeaklyTypedPreKeyState preKeyState =
new WeaklyTypedPreKeyState(null, null, null, wrongPreKey, identityKey.serialize());
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -779,14 +826,14 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(400);
} }
@Test @Test
void putKeysByPhoneNumberIdentifierTestV2() { void putKeysByPhoneNumberIdentifierTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@ -801,7 +848,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull()); verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey); assertThat(listCaptor.getValue()).containsExactly(preKey);
@ -813,11 +860,11 @@ class KeysControllerTest {
@Test @Test
void putKeysByPhoneNumberIdentifierPqTestV2() { void putKeysByPhoneNumberIdentifierPqTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final KEMSignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final KEMSignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@ -832,8 +879,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(ecCaptor.getValue()).containsExactly(preKey);
@ -846,7 +893,7 @@ class KeysControllerTest {
@Test @Test
void putPrekeyWithInvalidSignature() { void putPrekeyWithInvalidSignature() {
final SignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair()); final ECSignedPreKey badSignedPreKey = KeysHelper.signedECPreKey(1, Curve.generateKeyPair());
PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, badSignedPreKey, List.of()); PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, badSignedPreKey, List.of());
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -861,9 +908,9 @@ class KeysControllerTest {
@Test @Test
void disabledPutKeysTestV2() { void disabledPutKeysTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@ -877,13 +924,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue(); List<ECPreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1); assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); assertThat(capturedList.get(0).keyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey()).isEqualTo(preKey.getPublicKey()); assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey());
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
@ -892,10 +939,10 @@ class KeysControllerTest {
@Test @Test
void putIdentityKeyNonPrimary() { void putIdentityKeyNonPrimary() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final ECPreKey preKey = KeysHelper.ecPreKey(31337);
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, IDENTITY_KEY_PAIR);
List<PreKey> preKeys = List.of(preKey); List<ECPreKey> preKeys = List.of(preKey);
PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, signedPreKey, preKeys); PreKeyState preKeyState = new PreKeyState(IDENTITY_KEY, signedPreKey, preKeys);
@ -908,13 +955,4 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
} }
private void assertKeysMatch(PreKey expected, PreKey actual) {
assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId());
assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey());
if (expected instanceof final SignedPreKey signedExpected) {
final SignedPreKey signedActual = (SignedPreKey) actual;
assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature());
}
}
} }

View File

@ -12,7 +12,8 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson;
import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import java.util.Base64; import java.util.Base64;
@ -22,7 +23,7 @@ class PreKeyTest {
@Test @Test
void serializeToJSONV2() throws Exception { void serializeToJSONV2() throws Exception {
PreKey preKey = new PreKey(1234, PUBLIC_KEY); ECPreKey preKey = new ECPreKey(1234, new ECPublicKey(PUBLIC_KEY));
assertThat("PreKeyV2 Serialization works", assertThat("PreKeyV2 Serialization works",
asJson(preKey), asJson(preKey),

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.tests.util;
import java.util.Random; import java.util.Random;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;

View File

@ -7,26 +7,29 @@ package org.whispersystems.textsecuregcm.tests.util;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.kem.KEMKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyType; import org.signal.libsignal.protocol.kem.KEMKeyType;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.signal.libsignal.protocol.kem.KEMPublicKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
public final class KeysHelper { public final class KeysHelper {
public static PreKey ecPreKey(final long id) { public static ECPreKey ecPreKey(final long id) {
return new PreKey(id, Curve.generateKeyPair().getPublicKey().serialize()); return new ECPreKey(id, Curve.generateKeyPair().getPublicKey());
} }
public static SignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) { public static ECSignedPreKey signedECPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = Curve.generateKeyPair().getPublicKey().serialize(); final ECPublicKey pubKey = Curve.generateKeyPair().getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new SignedPreKey(id, pubKey, sig); return new ECSignedPreKey(id, pubKey, sig);
} }
public static SignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) { public static KEMSignedPreKey signedKEMPreKey(long id, final ECKeyPair identityKeyPair) {
final byte[] pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey().serialize(); final KEMPublicKey pubKey = KEMKeyPair.generate(KEMKeyType.KYBER_1024).getPublicKey();
final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey); final byte[] sig = identityKeyPair.getPrivateKey().calculateSignature(pubKey.serialize());
return new SignedPreKey(id, pubKey, sig); return new KEMSignedPreKey(id, pubKey, sig);
} }
} }