diff --git a/integration-tests/src/main/java/org/signal/integration/Operations.java b/integration-tests/src/main/java/org/signal/integration/Operations.java index e59f64d64..140eac416 100644 --- a/integration-tests/src/main/java/org/signal/integration/Operations.java +++ b/integration-tests/src/main/java/org/signal/integration/Operations.java @@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.tuple.Pair; import org.signal.integration.config.Config; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyPair; @@ -109,8 +110,8 @@ public final class Operations { registrationPassword, accountAttributes, true, - Optional.of(aciIdentityKeyPair.getPublicKey().serialize()), - Optional.of(pniIdentityKeyPair.getPublicKey().serialize()), + Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())), + Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())), Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)), Optional.of(generateSignedECPreKey(2, pniIdentityKeyPair)), Optional.of(generateSignedKEMPreKey(3, aciIdentityKeyPair)), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java index a700b8a4c..066ca5839 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java @@ -34,7 +34,7 @@ public class CertificateGenerator { SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder() .setSenderDevice(Math.toIntExact(device.getId())) .setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays)) - .setIdentityKey(ByteString.copyFrom(account.getIdentityKey())) + .setIdentityKey(ByteString.copyFrom(account.getIdentityKey().serialize())) .setSigner(serverCertificate) .setSenderUuid(account.getUuid().toString()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java index d73204202..2f28923c9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java @@ -32,7 +32,6 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; -import org.apache.commons.lang3.ArrayUtils; import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations; import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse; import org.signal.libsignal.zkgroup.GenericServerSecretParams; @@ -75,7 +74,7 @@ public class CertificateController { @QueryParam("includeE164") @DefaultValue("true") boolean includeE164) throws InvalidKeyException { - if (ArrayUtils.isEmpty(auth.getAccount().getIdentityKey())) { + if (auth.getAccount().getIdentityKey() == null) { throw new WebApplicationException(Response.Status.BAD_REQUEST); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index f4ed5eb0b..6de5d9993 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -11,14 +11,14 @@ import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; -import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import javax.validation.Valid; @@ -35,8 +35,7 @@ import javax.ws.rs.QueryParam; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; - -import org.apache.commons.lang3.ArrayUtils; +import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; @@ -116,11 +115,11 @@ public class KeysController { updateAccount = true; } - final byte[] oldIdentityKey = usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); - if (!Arrays.equals(preKeys.getIdentityKey(), oldIdentityKey)) { + final IdentityKey oldIdentityKey = usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); + if (!Objects.equals(preKeys.getIdentityKey(), oldIdentityKey)) { updateAccount = true; - final boolean hasIdentityKey = ArrayUtils.isNotEmpty(oldIdentityKey); + final boolean hasIdentityKey = oldIdentityKey != null; final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)) .and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey)) .and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci"); @@ -221,7 +220,7 @@ public class KeysController { } } - final byte[] identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); + final IdentityKey identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); if (responseItems.isEmpty()) { return Response.status(404).build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index f3cc67d2a..bf86a9a87 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -64,6 +64,7 @@ import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.apache.commons.lang3.StringUtils; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredentialResponse; @@ -383,14 +384,14 @@ public class ProfileController { if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) { return; } - final byte[] identityKeyBytes = + final IdentityKey identityKey = usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); md.reset(); - byte[] digest = md.digest(identityKeyBytes); + byte[] digest = md.digest(identityKey.serialize()); byte[] fingerprint = Util.truncate(digest, 4); if (!Arrays.equals(fingerprint, element.fingerprint())) { - responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKeyBytes)); + responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKey)); } }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java index 580211c83..5a94fc7cb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BaseProfileResponse.java @@ -8,7 +8,9 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import java.util.List; import java.util.UUID; @@ -16,9 +18,9 @@ import java.util.UUID; public class BaseProfileResponse { @JsonProperty - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - private byte[] identityKey; + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + private IdentityKey identityKey; @JsonProperty private String unidentifiedAccess; @@ -38,7 +40,7 @@ public class BaseProfileResponse { public BaseProfileResponse() { } - public BaseProfileResponse(final byte[] identityKey, + public BaseProfileResponse(final IdentityKey identityKey, final String unidentifiedAccess, final boolean unrestrictedUnidentifiedAccess, final UserCapabilities capabilities, @@ -53,7 +55,7 @@ public class BaseProfileResponse { this.uuid = uuid; } - public byte[] getIdentityKey() { + public IdentityKey getIdentityKey() { return identityKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java index becb15f73..442b73ab6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/BatchIdentityCheckResponse.java @@ -11,13 +11,25 @@ import java.util.UUID; import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.NotNull; -import org.whispersystems.textsecuregcm.util.ExactlySize; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; public record BatchIdentityCheckResponse(@Valid List elements) { - public record Element(@Deprecated @JsonInclude(JsonInclude.Include.NON_EMPTY) @Nullable UUID aci, - @JsonInclude(JsonInclude.Include.NON_EMPTY) @Nullable UUID uuid, - @NotNull @ExactlySize(33) byte[] identityKey) { + public record Element(@Deprecated + @JsonInclude(JsonInclude.Include.NON_EMPTY) + @Nullable UUID aci, + + @JsonInclude(JsonInclude.Include.NON_EMPTY) + @Nullable UUID uuid, + + @NotNull + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + IdentityKey identityKey) { public Element { if (aci == null && uuid == null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index c54bfa9b8..27d75e22e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; import java.util.ArrayList; import java.util.List; @@ -15,9 +16,10 @@ import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotBlank; -import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotNull; +import org.signal.libsignal.protocol.IdentityKey; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; @@ -39,8 +41,9 @@ public record ChangeNumberRequest( @JsonProperty("reglock") @Nullable String registrationLock, @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - @NotEmpty byte[] pniIdentityKey, + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + @NotNull IdentityKey pniIdentityKey, @Schema(description=""" A list of synchronization messages to send to companion devices to supply the private keysManager diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index d8f771c1d..6a922728d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -7,17 +7,18 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import io.swagger.v3.oas.annotations.media.Schema; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; - import java.util.ArrayList; import java.util.List; import java.util.Map; -import javax.validation.constraints.AssertTrue; import javax.annotation.Nullable; import javax.validation.Valid; +import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotNull; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; @@ -32,8 +33,9 @@ public record ChangePhoneNumberRequest( @JsonProperty("reglock") @Nullable String registrationLock, @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - @Nullable byte[] pniIdentityKey, + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + @Nullable IdentityKey pniIdentityKey, @Schema(description=""" A list of synchronization messages to send to companion devices to supply the private keysManager diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index 58d048bcb..d02ff8d2b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -5,27 +5,24 @@ package org.whispersystems.textsecuregcm.entities; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.swagger.v3.oas.annotations.media.Schema; import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.NotBlank; -import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotNull; - -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; - -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record PhoneNumberIdentityKeyDistributionRequest( - @NotEmpty - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) + @NotNull + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) @Schema(description="the new identity key for this account's phone-number identity") - byte[] pniIdentityKey, + IdentityKey pniIdentityKey, @NotNull @Valid diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java index 47349b7a2..c818746ec 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java @@ -6,19 +6,21 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.annotations.VisibleForTesting; import io.swagger.v3.oas.annotations.media.Schema; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; - import java.util.List; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; public class PreKeyResponse { @JsonProperty - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) @Schema(description="the public identity key for the requested identity") - private byte[] identityKey; + private IdentityKey identityKey; @JsonProperty @Schema(description="information about each requested device") @@ -26,13 +28,13 @@ public class PreKeyResponse { public PreKeyResponse() {} - public PreKeyResponse(byte[] identityKey, List devices) { + public PreKeyResponse(IdentityKey identityKey, List devices) { this.identityKey = identityKey; - this.devices = devices; + this.devices = devices; } @VisibleForTesting - public byte[] getIdentityKey() { + public IdentityKey getIdentityKey() { return identityKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java index fd7fb62b6..0eaebe835 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java @@ -9,27 +9,23 @@ import static com.codahale.metrics.MetricRegistry.name; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.util.Collection; -import org.signal.libsignal.protocol.InvalidKeyException; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.signal.libsignal.protocol.IdentityKey; public abstract class PreKeySignatureValidator { public static final Counter INVALID_SIGNATURE_COUNTER = Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")); - public static boolean validatePreKeySignatures(final byte[] identityKeyBytes, final Collection spks) { + public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection spks) { try { - final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0); - final boolean success = spks.stream() - .allMatch(spk -> identityKey.verifySignature(spk.getPublicKey(), spk.getSignature())); + .allMatch(spk -> identityKey.getPublicKey().verifySignature(spk.getPublicKey(), spk.getSignature())); if (!success) { INVALID_SIGNATURE_COUNTER.increment(); } return success; - } catch (IllegalArgumentException | InvalidKeyException e) { + } catch (final IllegalArgumentException e) { INVALID_SIGNATURE_COUNTER.increment(); return false; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java index 41756c6bc..d1479bd06 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java @@ -6,16 +6,16 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.annotations.VisibleForTesting; import io.swagger.v3.oas.annotations.media.Schema; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; - import java.util.ArrayList; import java.util.List; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.NotEmpty; import javax.validation.constraints.NotNull; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; @@ -55,24 +55,24 @@ public class PreKeyState { private SignedPreKey pqLastResortPreKey; @JsonProperty - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - @NotEmpty + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) @NotNull @Schema(description="Required. " + "The public identity key for this identity (account or phone-number identity). " + "If this device is not the primary device for the account, " + "must match the existing stored identity key for this identity.") - private byte[] identityKey; + private IdentityKey identityKey; public PreKeyState() {} @VisibleForTesting - public PreKeyState(byte[] identityKey, SignedPreKey signedPreKey, List keys) { + public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List keys) { this(identityKey, signedPreKey, keys, null, null); } @VisibleForTesting - public PreKeyState(byte[] identityKey, SignedPreKey signedPreKey, List keys, List pqKeys, SignedPreKey pqLastResortKey) { + public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List keys, List pqKeys, SignedPreKey pqLastResortKey) { this.identityKey = identityKey; this.signedPreKey = signedPreKey; this.preKeys = keys; @@ -96,7 +96,7 @@ public class PreKeyState { return pqLastResortPreKey; } - public byte[] getIdentityKey() { + public IdentityKey getIdentityKey() { return identityKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java index a377b592c..bd21f5367 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -12,16 +12,16 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.google.common.annotations.VisibleForTesting; import io.swagger.v3.oas.annotations.media.Schema; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; -import org.whispersystems.textsecuregcm.util.OptionalBase64ByteArrayDeserializer; -import org.whispersystems.textsecuregcm.util.ValidPreKey; -import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; - +import java.util.List; +import java.util.Optional; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotNull; -import java.util.List; -import java.util.Optional; +import org.signal.libsignal.protocol.IdentityKey; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; +import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter; +import org.whispersystems.textsecuregcm.util.ValidPreKey; +import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ The ID of an existing verification session as it appears in a verification session @@ -57,16 +57,18 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT provided, an account will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - @JsonDeserialize(using = OptionalBase64ByteArrayDeserializer.class) - Optional aciIdentityKey, + @JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class) + Optional aciIdentityKey, @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ The PNI-associated identity key for the account, encoded as a base64 string. If provided, an account will be created "atomically," and all other properties needed for atomic account creation must also be present. """) - @JsonDeserialize(using = OptionalBase64ByteArrayDeserializer.class) - Optional pniIdentityKey, + @JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class) + Optional pniIdentityKey, @JsonUnwrapped @JsonProperty(access = JsonProperty.Access.READ_ONLY) @@ -78,8 +80,8 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT @JsonProperty("recoveryPassword") byte[] recoveryPassword, @JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, - @JsonProperty("aciIdentityKey") Optional aciIdentityKey, - @JsonProperty("pniIdentityKey") Optional pniIdentityKey, + @JsonProperty("aciIdentityKey") Optional aciIdentityKey, + @JsonProperty("pniIdentityKey") Optional pniIdentityKey, @JsonProperty("aciSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> aciPqLastResortPreKey, @@ -103,7 +105,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - private static boolean validatePreKeySignature(final Optional maybeIdentityKey, + private static boolean validatePreKeySignature(final Optional maybeIdentityKey, final Optional maybeSignedPreKey) { return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index 300b97454..844e5da7a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -18,14 +18,15 @@ import java.util.Optional; import java.util.UUID; import java.util.function.Predicate; import javax.annotation.Nullable; +import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; -import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; +import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.Util; public class Account { @@ -66,14 +67,14 @@ public class Account { private List devices = new ArrayList<>(); @JsonProperty - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - private byte[] identityKey; + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + private IdentityKey identityKey; @JsonProperty("pniIdentityKey") - @JsonSerialize(using = ByteArrayAdapter.Serializing.class) - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) - private byte[] phoneNumberIdentityKey; + @JsonSerialize(using = IdentityKeyAdapter.Serializer.class) + @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) + private IdentityKey phoneNumberIdentityKey; @JsonProperty("cpv") private String currentProfileVersion; @@ -327,23 +328,23 @@ public class Account { this.canonicallyDiscoverable = canonicallyDiscoverable; } - public void setIdentityKey(byte[] identityKey) { + public void setIdentityKey(final IdentityKey identityKey) { requireNotStale(); this.identityKey = identityKey; } - public byte[] getIdentityKey() { + public IdentityKey getIdentityKey() { requireNotStale(); return identityKey; } - public byte[] getPhoneNumberIdentityKey() { + public IdentityKey getPhoneNumberIdentityKey() { return phoneNumberIdentityKey; } - public void setPhoneNumberIdentityKey(final byte[] phoneNumberIdentityKey) { + public void setPhoneNumberIdentityKey(final IdentityKey phoneNumberIdentityKey) { this.phoneNumberIdentityKey = phoneNumberIdentityKey; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 678b420ec..88212fde8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -30,10 +30,6 @@ import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import org.signal.libsignal.protocol.IdentityKey; -import org.signal.libsignal.protocol.InvalidKeyException; -import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -81,8 +77,6 @@ public class Accounts extends AbstractDynamoDbStore { private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset")); private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete")); - private static final String INVALID_IDENTITY_KEY_COUNTER_NAME = name(Accounts.class, "invalidIdentityKey"); - private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed"; private static final String TRANSACTION_CONFLICT = "TransactionConflict"; @@ -915,9 +909,6 @@ public class Accounts extends AbstractDynamoDbStore { .map(AttributeValue::bool) .orElse(false)); - checkIdentityKey(account.getUuid(), account.getIdentityKey(), "aci"); - checkIdentityKey(account.getUuid(), account.getPhoneNumberIdentityKey(), "pni"); - return account; } catch (final IOException e) { @@ -925,19 +916,6 @@ public class Accounts extends AbstractDynamoDbStore { } } - private static void checkIdentityKey(final UUID accountIdentifier, @Nullable final byte[] identityKey, final String keyType) { - if (identityKey != null && identityKey.length > 0) { - try { - new IdentityKey(identityKey); - } catch (final InvalidKeyException e) { - if (identityKey.length != ECPublicKey.KEY_SIZE - 1) { - log.warn("Account {} has an invalid {} identity key; length = {}", accountIdentifier, keyType, identityKey.length); - Metrics.counter(INVALID_IDENTITY_KEY_COUNTER_NAME, "type", keyType).increment(); - } - } - } - } - private static boolean conditionalCheckFailed(final CancellationReason reason) { return CONDITIONAL_CHECK_FAILED.equals(reason.code()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index e41eef878..8e7f08a17 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -27,6 +27,7 @@ import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.UUID; @@ -38,6 +39,7 @@ import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; +import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; @@ -255,7 +257,7 @@ public class AccountsManager { public Account changeNumber(final Account account, final String targetNumber, - @Nullable final byte[] pniIdentityKey, + @Nullable final IdentityKey pniIdentityKey, @Nullable final Map pniSignedPreKeys, @Nullable final Map pniPqLastResortPreKeys, @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { @@ -347,7 +349,7 @@ public class AccountsManager { } public Account updatePniKeys(final Account account, - final byte[] pniIdentityKey, + final IdentityKey pniIdentityKey, final Map pniSignedPreKeys, @Nullable final Map pniPqLastResortPreKeys, final Map pniRegistrationIds) throws MismatchedDevicesException { @@ -366,7 +368,7 @@ public class AccountsManager { } private boolean setPniKeys(final Account account, - @Nullable final byte[] pniIdentityKey, + @Nullable final IdentityKey pniIdentityKey, @Nullable final Map pniSignedPreKeys, @Nullable final Map pniRegistrationIds) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { @@ -375,7 +377,7 @@ public class AccountsManager { throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); } - boolean changed = !Arrays.equals(pniIdentityKey, account.getPhoneNumberIdentityKey()); + boolean changed = !Objects.equals(pniIdentityKey, account.getPhoneNumberIdentityKey()); for (Device device : account.getDevices()) { if (!device.isEnabled()) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index 9de38496d..e754a5fe3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -6,7 +6,14 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; +import org.signal.libsignal.protocol.IdentityKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.AccountController; @@ -19,12 +26,6 @@ import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; -import javax.annotation.Nullable; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; public class ChangeNumberManager { private static final Logger logger = LoggerFactory.getLogger(AccountController.class); @@ -39,7 +40,7 @@ public class ChangeNumberManager { } public Account changeNumber(final Account account, final String number, - @Nullable final byte[] pniIdentityKey, + @Nullable final IdentityKey pniIdentityKey, @Nullable final Map deviceSignedPreKeys, @Nullable final Map devicePqLastResortPreKeys, @Nullable final List deviceMessages, @@ -79,7 +80,7 @@ public class ChangeNumberManager { } public Account updatePniKeys(final Account account, - final byte[] pniIdentityKey, + final IdentityKey pniIdentityKey, final Map deviceSignedPreKeys, @Nullable final Map devicePqLastResortPreKeys, final List deviceMessages, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/IdentityKeyAdapter.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/IdentityKeyAdapter.java new file mode 100644 index 000000000..e79ea6d6b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/IdentityKeyAdapter.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.util.Base64; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.InvalidKeyException; +import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; + +public class IdentityKeyAdapter { + + private static final Counter IDENTITY_KEY_WITHOUT_VERSION_BYTE_COUNTER = + Metrics.counter(MetricsUtil.name(IdentityKeyAdapter.class), "identityKeyWithoutVersionByte"); + + public static class Serializer extends JsonSerializer { + + @Override + public void serialize(final IdentityKey identityKey, + final JsonGenerator jsonGenerator, + final SerializerProvider serializers) throws IOException { + + jsonGenerator.writeString(Base64.getEncoder().encodeToString(identityKey.serialize())); + } + } + + public static class Deserializer extends JsonDeserializer { + + @Override + public IdentityKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException { + final byte[] identityKeyBytes; + + try { + identityKeyBytes = Base64.getDecoder().decode(parser.getValueAsString()); + } catch (final IllegalArgumentException e) { + throw new JsonParseException(parser, "Could not parse identity key as a base64-encoded value", e); + } + + try { + return new IdentityKey(identityKeyBytes); + } catch (final InvalidKeyException e) { + if (identityKeyBytes.length == ECPublicKey.KEY_SIZE - 1) { + IDENTITY_KEY_WITHOUT_VERSION_BYTE_COUNTER.increment(); + return new IdentityKey(ECPublicKey.fromPublicKeyBytes(identityKeyBytes)); + } + + throw new JsonParseException(parser, "Could not interpret identity key bytes as an EC public key", e); + } + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalBase64ByteArrayDeserializer.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalBase64ByteArrayDeserializer.java deleted file mode 100644 index 677dd7af3..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalBase64ByteArrayDeserializer.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.whispersystems.textsecuregcm.util; - -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; - -import java.io.IOException; -import java.util.Base64; -import java.util.Optional; - -public class OptionalBase64ByteArrayDeserializer extends JsonDeserializer> { - - @Override - public Optional deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) throws IOException { - return Optional.of(Base64.getDecoder().decode(jsonParser.getValueAsString())); - } - - @Override - public Optional getNullValue(DeserializationContext ctxt) { - return Optional.empty(); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalIdentityKeyAdapter.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalIdentityKeyAdapter.java new file mode 100644 index 000000000..61a4b97cb --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/OptionalIdentityKeyAdapter.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.util.Base64; +import java.util.Optional; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.InvalidKeyException; + +public class OptionalIdentityKeyAdapter { + + public static class Serializer extends JsonSerializer> { + + @Override + public void serialize(final Optional maybePublicKey, + final JsonGenerator jsonGenerator, + final SerializerProvider serializers) throws IOException { + + if (maybePublicKey.isPresent()) { + jsonGenerator.writeString(Base64.getEncoder().encodeToString(maybePublicKey.get().serialize())); + } else { + jsonGenerator.writeNull(); + } + } + } + + public static class Deserializer extends JsonDeserializer> { + + @Override + public Optional deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) throws IOException { + try { + return Optional.of(new IdentityKey(Base64.getDecoder().decode(jsonParser.getValueAsString()))); + } catch (final InvalidKeyException e) { + throw new IOException(e); + } + } + + @Override + public Optional getNullValue(DeserializationContext ctxt) { + return Optional.empty(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java index f5ea93e34..959412e3e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java @@ -14,7 +14,9 @@ import java.security.InvalidKeyException; import java.util.Base64; import java.util.UUID; import org.junit.jupiter.api.Test; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -22,7 +24,7 @@ class CertificateGeneratorTest { private static final String SIGNING_CERTIFICATE = "CiUIDBIhBbTz4h1My+tt+vw+TVscgUe/DeHS0W02tPWAWbTO2xc3EkD+go4bJnU0AcnFfbOLKoiBfCzouZtDYMOVi69rE7r4U9cXREEqOkUmU2WJBjykAxWPCcSTmVTYHDw7hkSp/puG"; private static final String SIGNING_KEY = "ABOxG29xrfq4E7IrW11Eg7+HBbtba9iiS0500YoBjn4="; - private static final byte[] IDENTITY_KEY = Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"); + private static final IdentityKey IDENTITY_KEY = new IdentityKey(ECPublicKey.fromPublicKeyBytes(Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"))); @Test void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 5f4c1a02a..54f41aaa5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -70,6 +70,7 @@ import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.usernames.BaseUsernameException; @@ -339,7 +340,7 @@ class AccountControllerTest { when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0); final String number = invocation.getArgument(1); - final byte[] pniIdentityKey = invocation.getArgument(2); + final IdentityKey pniIdentityKey = invocation.getArgument(2); final UUID uuid = account.getUuid(); final UUID pni = number.equals(account.getNumber()) ? account.getPhoneNumberIdentifier() : UUID.randomUUID(); @@ -362,7 +363,7 @@ class AccountControllerTest { when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0); - final byte[] pniIdentityKey = invocation.getArgument(1); + final IdentityKey pniIdentityKey = invocation.getArgument(1); final String number = account.getNumber(); final UUID uuid = account.getUuid(); @@ -1646,7 +1647,7 @@ class AccountControllerTest { final String number = "+18005559876"; final String code = "987654"; final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); Device device2 = mock(Device.class); @@ -1700,7 +1701,7 @@ class AccountControllerTest { void testChangePhoneNumberSameNumberChangePrekeys() throws Exception { final String code = "987654"; final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); Device device2 = mock(Device.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index fb1a9917d..5ad94b169 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -61,6 +61,7 @@ import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -97,6 +98,8 @@ class AccountControllerV2Test { private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds(); + private static final IdentityKey IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); @@ -140,7 +143,7 @@ class AccountControllerV2Test { (Answer) invocation -> { final Account account = invocation.getArgument(0); final String number = invocation.getArgument(1); - final byte[] pniIdentityKey = invocation.getArgument(2); + final IdentityKey pniIdentityKey = invocation.getArgument(2); final UUID uuid = account.getUuid(); final List devices = account.getDevices(); @@ -180,7 +183,7 @@ class AccountControllerV2Test { .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity( - new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123".getBytes(StandardCharsets.UTF_8), + new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()), Collections.emptyList(), Collections.emptyMap(), null, Collections.emptyMap()), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -203,7 +206,7 @@ class AccountControllerV2Test { AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity( new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, - "pni-identity-key".getBytes(StandardCharsets.UTF_8), + new IdentityKey(Curve.generateKeyPair().getPublicKey()), Collections.emptyList(), Collections.emptyMap(), null, Collections.emptyMap()), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -407,12 +410,12 @@ class AccountControllerV2Test { "recoveryPassword": "%s", "number": "%s", "reglock": "1234", - "pniIdentityKey": "5678", + "pniIdentityKey": "%s", "deviceMessages": [], "devicePniSignedPrekeys": {}, "pniRegistrationIds": {} } - """, encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber); + """, encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize())); } /** @@ -463,7 +466,7 @@ class AccountControllerV2Test { when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer( (Answer) invocation -> { final Account account = invocation.getArgument(0); - final byte[] pniIdentityKey = invocation.getArgument(1); + final IdentityKey pniIdentityKey = invocation.getArgument(1); final UUID uuid = account.getUuid(); final UUID pni = account.getPhoneNumberIdentifier(); @@ -498,7 +501,7 @@ class AccountControllerV2Test { AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(requestJson()), AccountIdentityResponse.class); - verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key".getBytes(StandardCharsets.UTF_8)), any(), any(), any(), any()); + verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq(IDENTITY_KEY), any(), any(), any(), any()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); @@ -562,7 +565,7 @@ class AccountControllerV2Test { "devicePniSignedPqPrekeys": {}, "pniRegistrationIds": {} } - """, Base64.getEncoder().encodeToString("pni-identity-key".getBytes(StandardCharsets.UTF_8))); + """, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize())); } /** @@ -798,8 +801,8 @@ class AccountControllerV2Test { account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess); account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber); account.setBadges(Clock.systemUTC(), new ArrayList<>(badges)); - account.setIdentityKey(aciIdentityKeyPair.getPublicKey().serialize()); - account.setPhoneNumberIdentityKey(pniIdentityKeyPair.getPublicKey().serialize()); + account.setIdentityKey(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + account.setPhoneNumberIdentityKey(new IdentityKey(pniIdentityKeyPair.getPublicKey())); assert !devices.isEmpty(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index 1d5dc5ed1..1faa26971 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -33,12 +33,12 @@ import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; -import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HexFormat; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.Executors; @@ -61,6 +61,8 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.ServerPublicParams; import org.signal.libsignal.zkgroup.ServerSecretParams; @@ -124,10 +126,10 @@ class ProfileControllerTest { private static final ServerZkProfileOperations zkProfileOperations = mock(ServerZkProfileOperations.class); private static final byte[] UNIDENTIFIED_ACCESS_KEY = "test-uak".getBytes(StandardCharsets.UTF_8); - private static final byte[] ACCOUNT_IDENTITY_KEY = "barz".getBytes(StandardCharsets.UTF_8); - private static final byte[] ACCOUNT_PHONE_NUMBER_IDENTITY_KEY = "bazz".getBytes(StandardCharsets.UTF_8); - private static final byte[] ACCOUNT_TWO_IDENTITY_KEY = "bar".getBytes(StandardCharsets.UTF_8); - private static final byte[] ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY = "baz".getBytes(StandardCharsets.UTF_8); + private static final IdentityKey ACCOUNT_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + private static final IdentityKey ACCOUNT_PHONE_NUMBER_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + private static final IdentityKey ACCOUNT_TWO_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + private static final IdentityKey ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey()); private static final String BASE_64_URL_USERNAME_HASH = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE"; private static final byte[] USERNAME_HASH = Base64.getUrlDecoder().decode(BASE_64_URL_USERNAME_HASH); @SuppressWarnings("unchecked") @@ -1170,26 +1172,31 @@ class ProfileControllerTest { final Condition isAnExpectedUuid = new Condition<>(element -> { if (AuthHelper.VALID_UUID.equals(element.aci())) { - return Arrays.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); + return Objects.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { - return Arrays.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); + return Objects.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); } else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) { - return Arrays.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey()); + return Objects.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey()); } else { return false; } }, "is an expected UUID with the correct identity key"); + final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + final IdentityKey secondValidAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + final IdentityKey invalidAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.json(new BatchIdentityCheckRequest(List.of( new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, - convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), + convertKeyToFingerprint(validAciIdentityKey)), new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, - convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), + convertKeyToFingerprint(secondValidPniIdentityKey)), new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, - convertKeyToFingerprint("another2".getBytes(StandardCharsets.UTF_8))), + convertKeyToFingerprint(secondValidAciIdentityKey)), new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, - convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8))) + convertKeyToFingerprint(invalidAciIdentityKey)) ))))) { assertThat(response).isNotNull(); assertThat(response.getStatus()).isEqualTo(200); @@ -1202,13 +1209,13 @@ class ProfileControllerTest { } final List largeElementList = new ArrayList<>(List.of( - new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), - new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), - new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8))))); + new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint(validAciIdentityKey)), + new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint(secondValidPniIdentityKey)), + new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint(invalidAciIdentityKey)))); for (int i = 0; i < 900; i++) { largeElementList.add( - new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint("abcd".getBytes(StandardCharsets.UTF_8)))); + new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); } try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() @@ -1228,9 +1235,9 @@ class ProfileControllerTest { final Condition isAnExpectedUuid = new Condition<>(element -> { if (AuthHelper.VALID_UUID.equals(element.aci())) { - return Arrays.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); + return ACCOUNT_IDENTITY_KEY.equals(element.identityKey()); } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { - return Arrays.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); + return ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY.equals(element.identityKey()); } else { return false; } @@ -1245,9 +1252,9 @@ class ProfileControllerTest { { "aci": "%s", "fingerprint": "%s" } ] } - """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), - AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), - AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8)))); + """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), + AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))), + AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))); try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() .post(Entity.entity(json, "application/json"))) { @@ -1313,15 +1320,15 @@ class ProfileControllerTest { ] } """, AuthHelper.VALID_UUID, AuthHelper.VALID_PNI, - Base64.getEncoder().encodeToString(convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))))) + Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))))) ); } - private static byte[] convertKeyToFingerprint(byte[] key) { + private static byte[] convertKeyToFingerprint(final IdentityKey publicKey) { try { - return Util.truncate(MessageDigest.getInstance("SHA-256").digest(key), 4); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); + return Util.truncate(MessageDigest.getInstance("SHA-256").digest(publicKey.serialize()), 4); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("All Java implementations must support SHA-256 MessageDigest algorithm", e); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index 8f9e53707..8703e69f4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; @@ -415,8 +416,8 @@ class RegistrationControllerTest { } static Stream atomicAccountCreationConflictingChannel() { - final Optional aciIdentityKey; - final Optional pniIdentityKey; + final Optional aciIdentityKey; + final Optional pniIdentityKey; final Optional aciSignedPreKey; final Optional pniSignedPreKey; final Optional aciPqLastResortPreKey; @@ -425,8 +426,8 @@ class RegistrationControllerTest { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); - pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); + aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); @@ -504,8 +505,8 @@ class RegistrationControllerTest { } static Stream atomicAccountCreationPartialSignedPreKeys() { - final Optional aciIdentityKey; - final Optional pniIdentityKey; + final Optional aciIdentityKey; + final Optional pniIdentityKey; final Optional aciSignedPreKey; final Optional pniSignedPreKey; final Optional aciPqLastResortPreKey; @@ -514,8 +515,8 @@ class RegistrationControllerTest { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); - pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); + aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); @@ -617,8 +618,8 @@ class RegistrationControllerTest { @MethodSource @SuppressWarnings("OptionalUsedAsFieldOrParameterType") void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, - final byte[] expectedAciIdentityKey, - final byte[] expectedPniIdentityKey, + final IdentityKey expectedAciIdentityKey, + final IdentityKey expectedPniIdentityKey, final SignedPreKey expectedAciSignedPreKey, final SignedPreKey expectedPniSignedPreKey, final SignedPreKey expectedAciPqLastResortPreKey, @@ -683,8 +684,8 @@ class RegistrationControllerTest { } private static Stream atomicAccountCreationSuccess() { - final Optional aciIdentityKey; - final Optional pniIdentityKey; + final Optional aciIdentityKey; + final Optional pniIdentityKey; final Optional aciSignedPreKey; final Optional pniSignedPreKey; final Optional aciPqLastResortPreKey; @@ -693,8 +694,8 @@ class RegistrationControllerTest { final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); - pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); + aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 01c94b996..bced40252 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -5,7 +5,9 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -21,6 +23,7 @@ import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -155,7 +158,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final Map preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); @@ -172,7 +175,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); - assertArrayEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); + assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); assertEquals(OptionalInt.of(rotatedPniRegistrationId), updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentityRegistrationId()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 7a105ffd3..9e3e799df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -20,7 +20,6 @@ import static org.mockito.Mockito.when; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Instant; import java.util.ArrayList; @@ -38,6 +37,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.AccountAttributes; @@ -147,7 +148,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { final boolean discoverableByPhoneNumber = false; final String currentProfileVersion = "cpv"; - final byte[] identityKey = "ikey".getBytes(StandardCharsets.UTF_8); + final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final byte[] unidentifiedAccessKey = new byte[]{1}; final String pin = "1234"; final String registrationLock = "reglock"; @@ -189,12 +190,12 @@ class AccountsManagerConcurrentModificationIntegrationTest { return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class); } - private void verifyAccount(final String name, final Account account, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final byte[] identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAccess, final long lastSeen) { + private void verifyAccount(final String name, final Account account, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final IdentityKey identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAccess, final long lastSeen) { assertAll(name, () -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()), () -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().orElseThrow()), - () -> assertArrayEquals(identityKey, account.getIdentityKey()), + () -> assertEquals(identityKey, account.getIdentityKey()), () -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().orElseThrow()), () -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock)), () -> assertEquals(unrestrictedUnidentifiedAccess, account.isUnrestrictedUnidentifiedAccess()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 66665c3a2..7260c5e4e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -5,7 +5,12 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; @@ -24,8 +29,6 @@ import static org.mockito.Mockito.when; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; - -import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; import java.util.ArrayList; @@ -46,6 +49,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.stubbing.Answer; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -473,10 +477,12 @@ class AccountsManagerTest { .doAnswer(ACCOUNT_UPDATE_ANSWER) .when(accounts).update(any()); - account = accountsManager.update(account, a -> a.setIdentityKey("identity-key".getBytes(StandardCharsets.UTF_8))); + final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + account = accountsManager.update(account, a -> a.setIdentityKey(identityKey)); assertEquals(1, account.getVersion()); - assertArrayEquals("identity-key".getBytes(StandardCharsets.UTF_8), account.getIdentityKey()); + assertEquals(identityKey, account.getIdentityKey()); verify(accounts, times(1)).getByAccountIdentifier(uuid); verify(accounts, times(2)).update(any()); @@ -669,7 +675,7 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); assertThrows(IllegalArgumentException.class, () -> accountsManager.changeNumber( - account, number, "new-identity-key".getBytes(StandardCharsets.UTF_8), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), + account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); verify(accounts, never()).update(any()); @@ -728,7 +734,7 @@ class AccountsManagerTest { final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); final Account updatedAccount = accountsManager.changeNumber( - account, targetNumber, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, newSignedPqKeys, newRegistrationIds); + account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); assertEquals(targetNumber, updatedAccount.getNumber()); @@ -771,7 +777,9 @@ class AccountsManagerTest { UUID oldPni = account.getPhoneNumberIdentifier(); Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); - final Account updatedAccount = accountsManager.updatePniKeys(account, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, null, newRegistrationIds); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds); // non-PNI stuff should not change assertEquals(oldUuid, updatedAccount.getUuid()); @@ -783,7 +791,7 @@ class AccountsManagerTest { updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); // PNI stuff should - assertArrayEquals("new-pni-identity-key".getBytes(StandardCharsets.UTF_8), updatedAccount.getPhoneNumberIdentityKey()); + assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); assertEquals(newSignedKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); assertEquals(newRegistrationIds, @@ -817,8 +825,10 @@ class AccountsManagerTest { Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + final Account updatedAccount = - accountsManager.updatePniKeys(account, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, newSignedPqKeys, newRegistrationIds); + accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds); // non-PNI-keys stuff should not change assertEquals(oldUuid, updatedAccount.getUuid()); @@ -830,7 +840,7 @@ class AccountsManagerTest { updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); // PNI keys should - assertArrayEquals("new-pni-identity-key".getBytes(StandardCharsets.UTF_8), updatedAccount.getPhoneNumberIdentityKey()); + assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); assertEquals(newSignedKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); assertEquals(newRegistrationIds, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 35dd0377f..fcad08cb0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -14,7 +14,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Base64; import java.util.Collections; @@ -27,6 +26,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos; @@ -106,7 +107,7 @@ public class ChangeNumberManagerTest { Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); var prekeys = Map.of(1L, new SignedPreKey()); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(StandardCharsets.UTF_8); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); @@ -132,7 +133,7 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); @@ -175,7 +176,7 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(StandardCharsets.UTF_8); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); @@ -217,7 +218,7 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); @@ -257,7 +258,7 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); @@ -296,7 +297,7 @@ public class ChangeNumberManagerTest { when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); - final byte[] pniIdentityKey = "pni-identity-key".getBytes(); + final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); @@ -347,7 +348,7 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, - () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key".getBytes(StandardCharsets.UTF_8), preKeys, null, messages, registrationIds)); + () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds)); } @Test @@ -377,7 +378,7 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, - () -> changeNumberManager.updatePniKeys(account, "pni-identity-key".getBytes(StandardCharsets.UTF_8), preKeys, null, messages, registrationIds)); + () -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds)); } @Test @@ -406,6 +407,6 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(IllegalArgumentException.class, - () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key".getBytes(StandardCharsets.UTF_8), null, null, messages, registrationIds)); + () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), null, null, messages, registrationIds)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java index d9f9892ee..9807a2e66 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java @@ -113,7 +113,7 @@ class CertificateControllerTest { assertEquals(certificate.getSenderDevice(), 1L); assertTrue(certificate.hasSenderUuid()); assertEquals(AuthHelper.VALID_UUID.toString(), certificate.getSenderUuid()); - assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); + assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize()); } @Test @@ -141,7 +141,7 @@ class CertificateControllerTest { assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER); assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString()); - assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); + assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize()); } @Test @@ -170,7 +170,7 @@ class CertificateControllerTest { assertTrue(StringUtils.isBlank(certificate.getSender())); assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString()); - assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); + assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index bdf04aec0..a693ff785 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -42,6 +42,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -277,8 +278,8 @@ class DeviceControllerTest { aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - when(account.getIdentityKey()).thenReturn(aciIdentityKeyPair.getPublicKey().serialize()); - when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKeyPair.getPublicKey().serialize()); + when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); final LinkDeviceRequest request = new LinkDeviceRequest("5678901", new AccountAttributes(fetchesMessages, 1234, null, null, true, null), @@ -363,8 +364,8 @@ class DeviceControllerTest { aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - when(account.getIdentityKey()).thenReturn(aciIdentityKeyPair.getPublicKey().serialize()); - when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKeyPair.getPublicKey().serialize()); + when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); final LinkDeviceRequest request = new LinkDeviceRequest("5678901", new AccountAttributes(fetchesMessages, 1234, null, null, true, null), @@ -392,8 +393,8 @@ class DeviceControllerTest { @ParameterizedTest @MethodSource @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - void linkDeviceAtomicMissingProperty(final byte[] aciIdentityKey, - final byte[] pniIdentityKey, + void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey, + final IdentityKey pniIdentityKey, final Optional aciSignedPreKey, final Optional pniSignedPreKey, final Optional aciPqLastResortPreKey, @@ -439,8 +440,8 @@ class DeviceControllerTest { final Optional aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); final Optional pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - final byte[] aciIdentityKey = aciIdentityKeyPair.getPublicKey().serialize(); - final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); + final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); return Stream.of( Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), @@ -452,8 +453,8 @@ class DeviceControllerTest { @ParameterizedTest @MethodSource - void linkDeviceAtomicInvalidSignature(final byte[] aciIdentityKey, - final byte[] pniIdentityKey, + void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey, + final IdentityKey pniIdentityKey, final SignedPreKey aciSignedPreKey, final SignedPreKey pniSignedPreKey, final SignedPreKey aciPqLastResortPreKey, @@ -499,8 +500,8 @@ class DeviceControllerTest { final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); - final byte[] aciIdentityKey = aciIdentityKeyPair.getPublicKey().serialize(); - final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); + final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey()); + final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); return Stream.of( Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 0b5f12a3f..edbce20b4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -39,6 +39,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; +import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -80,10 +81,10 @@ class KeysControllerTest { private static final int SAMPLE_PNI_REGISTRATION_ID = 1717; private final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); - private final byte[] IDENTITY_KEY = IDENTITY_KEY_PAIR.getPublicKey().serialize(); + private final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey()); private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); - private final byte[] PNI_IDENTITY_KEY = PNI_IDENTITY_KEY_PAIR.getPublicKey().serialize(); + private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey()); private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); @@ -658,7 +659,7 @@ class KeysControllerTest { final PreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); @@ -688,7 +689,7 @@ class KeysControllerTest { final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); @@ -716,7 +717,7 @@ class KeysControllerTest { @Test void putKeysStructurallyInvalidSignedECKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair); final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null); @@ -729,11 +730,11 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(422); } - + @Test void putKeysStructurallyInvalidUnsignedECKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes()); final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null); @@ -746,11 +747,11 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(422); } - + @Test void putKeysStructurallyInvalidPQOneTimeKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null); @@ -763,11 +764,11 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(422); } - + @Test void putKeysStructurallyInvalidPQLastResortKey() { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey); @@ -780,13 +781,13 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(422); } - + @Test void putKeysByPhoneNumberIdentifierTestV2() { final PreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); @@ -817,7 +818,7 @@ class KeysControllerTest { final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); @@ -860,10 +861,10 @@ class KeysControllerTest { @Test void disabledPutKeysTestV2() { - final PreKey preKey = KeysHelper.ecPreKey(31337); + final PreKey preKey = KeysHelper.ecPreKey(31337); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); - final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); + final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey()); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 1c421c960..0a92c212b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -20,6 +20,8 @@ import java.util.Base64; import java.util.Optional; import java.util.Random; import java.util.UUID; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; @@ -63,7 +65,8 @@ public class AuthHelper { public static final UUID UNDISCOVERABLE_UUID = UUID.randomUUID(); public static final String UNDISCOVERABLE_PASSWORD = "IT'S A SECRET TO EVERYBODY."; - public static final byte[] VALID_IDENTITY = Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"); + public static final IdentityKey VALID_IDENTITY = new IdentityKey(ECPublicKey.fromPublicKeyBytes( + Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"))); public static AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); public static Account VALID_ACCOUNT = mock(Account.class );