Represent identity keys as `IdentityKey` instances

This commit is contained in:
Jon Chambers 2023-06-08 11:36:58 -04:00 committed by GitHub
parent 1c8443210a
commit 234707169e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 390 additions and 263 deletions

View File

@ -29,6 +29,7 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate; import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.signal.integration.config.Config; import org.signal.integration.config.Config;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyPair; import org.signal.libsignal.protocol.kem.KEMKeyPair;
@ -109,8 +110,8 @@ public final class Operations {
registrationPassword, registrationPassword,
accountAttributes, accountAttributes,
true, true,
Optional.of(aciIdentityKeyPair.getPublicKey().serialize()), Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())),
Optional.of(pniIdentityKeyPair.getPublicKey().serialize()), Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())),
Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)), Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)),
Optional.of(generateSignedECPreKey(2, pniIdentityKeyPair)), Optional.of(generateSignedECPreKey(2, pniIdentityKeyPair)),
Optional.of(generateSignedKEMPreKey(3, aciIdentityKeyPair)), Optional.of(generateSignedKEMPreKey(3, aciIdentityKeyPair)),

View File

@ -34,7 +34,7 @@ public class CertificateGenerator {
SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder() SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder()
.setSenderDevice(Math.toIntExact(device.getId())) .setSenderDevice(Math.toIntExact(device.getId()))
.setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays)) .setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays))
.setIdentityKey(ByteString.copyFrom(account.getIdentityKey())) .setIdentityKey(ByteString.copyFrom(account.getIdentityKey().serialize()))
.setSigner(serverCertificate) .setSigner(serverCertificate)
.setSenderUuid(account.getUuid().toString()); .setSenderUuid(account.getUuid().toString());

View File

@ -32,7 +32,6 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.apache.commons.lang3.ArrayUtils;
import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations; import org.signal.libsignal.zkgroup.auth.ServerZkAuthOperations;
import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse; import org.signal.libsignal.zkgroup.calllinks.CallLinkAuthCredentialResponse;
import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.GenericServerSecretParams;
@ -75,7 +74,7 @@ public class CertificateController {
@QueryParam("includeE164") @DefaultValue("true") boolean includeE164) @QueryParam("includeE164") @DefaultValue("true") boolean includeE164)
throws InvalidKeyException { throws InvalidKeyException {
if (ArrayUtils.isEmpty(auth.getAccount().getIdentityKey())) { if (auth.getAccount().getIdentityKey() == null) {
throw new WebApplicationException(Response.Status.BAD_REQUEST); throw new WebApplicationException(Response.Status.BAD_REQUEST);
} }

View File

@ -11,14 +11,14 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import javax.validation.Valid; import javax.validation.Valid;
@ -35,8 +35,7 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.signal.libsignal.protocol.IdentityKey;
import org.apache.commons.lang3.ArrayUtils;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
@ -116,11 +115,11 @@ public class KeysController {
updateAccount = true; updateAccount = true;
} }
final byte[] oldIdentityKey = usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); final IdentityKey oldIdentityKey = usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey();
if (!Arrays.equals(preKeys.getIdentityKey(), oldIdentityKey)) { if (!Objects.equals(preKeys.getIdentityKey(), oldIdentityKey)) {
updateAccount = true; updateAccount = true;
final boolean hasIdentityKey = ArrayUtils.isNotEmpty(oldIdentityKey); final boolean hasIdentityKey = oldIdentityKey != null;
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)) final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))
.and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey)) .and(HAS_IDENTITY_KEY_TAG_NAME, String.valueOf(hasIdentityKey))
.and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci"); .and(IDENTITY_TYPE_TAG_NAME, usePhoneNumberIdentity ? "pni" : "aci");
@ -221,7 +220,7 @@ public class KeysController {
} }
} }
final byte[] identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); final IdentityKey identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey();
if (responseItems.isEmpty()) { if (responseItems.isEmpty()) {
return Response.status(404).build(); return Response.status(404).build();

View File

@ -64,6 +64,7 @@ import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredentialResponse; import org.signal.libsignal.zkgroup.profiles.ExpiringProfileKeyCredentialResponse;
@ -383,14 +384,14 @@ public class ProfileController {
if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) { if (account.getIdentityKey() == null || account.getPhoneNumberIdentityKey() == null) {
return; return;
} }
final byte[] identityKeyBytes = final IdentityKey identityKey =
usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey(); usePhoneNumberIdentity ? account.getPhoneNumberIdentityKey() : account.getIdentityKey();
md.reset(); md.reset();
byte[] digest = md.digest(identityKeyBytes); byte[] digest = md.digest(identityKey.serialize());
byte[] fingerprint = Util.truncate(digest, 4); byte[] fingerprint = Util.truncate(digest, 4);
if (!Arrays.equals(fingerprint, element.fingerprint())) { if (!Arrays.equals(fingerprint, element.fingerprint())) {
responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKeyBytes)); responseElements.add(new BatchIdentityCheckResponse.Element(element.aci(), element.uuid(), identityKey));
} }
}); });
} }

View File

@ -8,7 +8,9 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
@ -16,9 +18,9 @@ import java.util.UUID;
public class BaseProfileResponse { public class BaseProfileResponse {
@JsonProperty @JsonProperty
@JsonSerialize(using = ByteArrayAdapter.Serializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
private byte[] identityKey; private IdentityKey identityKey;
@JsonProperty @JsonProperty
private String unidentifiedAccess; private String unidentifiedAccess;
@ -38,7 +40,7 @@ public class BaseProfileResponse {
public BaseProfileResponse() { public BaseProfileResponse() {
} }
public BaseProfileResponse(final byte[] identityKey, public BaseProfileResponse(final IdentityKey identityKey,
final String unidentifiedAccess, final String unidentifiedAccess,
final boolean unrestrictedUnidentifiedAccess, final boolean unrestrictedUnidentifiedAccess,
final UserCapabilities capabilities, final UserCapabilities capabilities,
@ -53,7 +55,7 @@ public class BaseProfileResponse {
this.uuid = uuid; this.uuid = uuid;
} }
public byte[] getIdentityKey() { public IdentityKey getIdentityKey() {
return identityKey; return identityKey;
} }

View File

@ -11,13 +11,25 @@ import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ExactlySize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
public record BatchIdentityCheckResponse(@Valid List<Element> elements) { public record BatchIdentityCheckResponse(@Valid List<Element> elements) {
public record Element(@Deprecated @JsonInclude(JsonInclude.Include.NON_EMPTY) @Nullable UUID aci, public record Element(@Deprecated
@JsonInclude(JsonInclude.Include.NON_EMPTY) @Nullable UUID uuid, @JsonInclude(JsonInclude.Include.NON_EMPTY)
@NotNull @ExactlySize(33) byte[] identityKey) { @Nullable UUID aci,
@JsonInclude(JsonInclude.Include.NON_EMPTY)
@Nullable UUID uuid,
@NotNull
@JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
IdentityKey identityKey) {
public Element { public Element {
if (aci == null && uuid == null) { if (aci == null && uuid == null) {

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -15,9 +16,10 @@ import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
@ -39,8 +41,9 @@ public record ChangeNumberRequest(
@JsonProperty("reglock") @Nullable String registrationLock, @JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@NotEmpty byte[] pniIdentityKey, @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@NotNull IdentityKey pniIdentityKey,
@Schema(description=""" @Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keysManager A list of synchronization messages to send to companion devices to supply the private keysManager

View File

@ -7,17 +7,18 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.validation.constraints.AssertTrue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
@ -32,8 +33,9 @@ public record ChangePhoneNumberRequest(
@JsonProperty("reglock") @Nullable String registrationLock, @JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@Nullable byte[] pniIdentityKey, @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@Nullable IdentityKey pniIdentityKey,
@Schema(description=""" @Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keysManager A list of synchronization messages to send to companion devices to supply the private keysManager

View File

@ -5,27 +5,24 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record PhoneNumberIdentityKeyDistributionRequest( public record PhoneNumberIdentityKeyDistributionRequest(
@NotEmpty @NotNull
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@Schema(description="the new identity key for this account's phone-number identity") @Schema(description="the new identity key for this account's phone-number identity")
byte[] pniIdentityKey, IdentityKey pniIdentityKey,
@NotNull @NotNull
@Valid @Valid

View File

@ -6,19 +6,21 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import java.util.List; import java.util.List;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
public class PreKeyResponse { public class PreKeyResponse {
@JsonProperty @JsonProperty
@JsonSerialize(using = ByteArrayAdapter.Serializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@Schema(description="the public identity key for the requested identity") @Schema(description="the public identity key for the requested identity")
private byte[] identityKey; private IdentityKey identityKey;
@JsonProperty @JsonProperty
@Schema(description="information about each requested device") @Schema(description="information about each requested device")
@ -26,13 +28,13 @@ public class PreKeyResponse {
public PreKeyResponse() {} public PreKeyResponse() {}
public PreKeyResponse(byte[] identityKey, List<PreKeyResponseItem> devices) { public PreKeyResponse(IdentityKey identityKey, List<PreKeyResponseItem> devices) {
this.identityKey = identityKey; this.identityKey = identityKey;
this.devices = devices; this.devices = devices;
} }
@VisibleForTesting @VisibleForTesting
public byte[] getIdentityKey() { public IdentityKey getIdentityKey() {
return identityKey; return identityKey;
} }

View File

@ -9,27 +9,23 @@ import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Collection; import java.util.Collection;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
public abstract class PreKeySignatureValidator { public abstract class PreKeySignatureValidator {
public static final Counter INVALID_SIGNATURE_COUNTER = public static final Counter INVALID_SIGNATURE_COUNTER =
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")); Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature"));
public static boolean validatePreKeySignatures(final byte[] identityKeyBytes, final Collection<SignedPreKey> spks) { public static boolean validatePreKeySignatures(final IdentityKey identityKey, final Collection<SignedPreKey> spks) {
try { try {
final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0);
final boolean success = spks.stream() final boolean success = spks.stream()
.allMatch(spk -> identityKey.verifySignature(spk.getPublicKey(), spk.getSignature())); .allMatch(spk -> identityKey.getPublicKey().verifySignature(spk.getPublicKey(), spk.getSignature()));
if (!success) { if (!success) {
INVALID_SIGNATURE_COUNTER.increment(); INVALID_SIGNATURE_COUNTER.increment();
} }
return success; return success;
} catch (IllegalArgumentException | InvalidKeyException e) { } catch (final IllegalArgumentException e) {
INVALID_SIGNATURE_COUNTER.increment(); INVALID_SIGNATURE_COUNTER.increment();
return false; return false;
} }

View File

@ -6,16 +6,16 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey; import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType; import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
@ -55,24 +55,24 @@ public class PreKeyState {
private SignedPreKey pqLastResortPreKey; private SignedPreKey pqLastResortPreKey;
@JsonProperty @JsonProperty
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@NotEmpty @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@NotNull @NotNull
@Schema(description="Required. " + @Schema(description="Required. " +
"The public identity key for this identity (account or phone-number identity). " + "The public identity key for this identity (account or phone-number identity). " +
"If this device is not the primary device for the account, " + "If this device is not the primary device for the account, " +
"must match the existing stored identity key for this identity.") "must match the existing stored identity key for this identity.")
private byte[] identityKey; private IdentityKey identityKey;
public PreKeyState() {} public PreKeyState() {}
@VisibleForTesting @VisibleForTesting
public PreKeyState(byte[] identityKey, SignedPreKey signedPreKey, List<PreKey> keys) { public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List<PreKey> keys) {
this(identityKey, signedPreKey, keys, null, null); this(identityKey, signedPreKey, keys, null, null);
} }
@VisibleForTesting @VisibleForTesting
public PreKeyState(byte[] identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) { public PreKeyState(IdentityKey identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) {
this.identityKey = identityKey; this.identityKey = identityKey;
this.signedPreKey = signedPreKey; this.signedPreKey = signedPreKey;
this.preKeys = keys; this.preKeys = keys;
@ -96,7 +96,7 @@ public class PreKeyState {
return pqLastResortPreKey; return pqLastResortPreKey;
} }
public byte[] getIdentityKey() { public IdentityKey getIdentityKey() {
return identityKey; return identityKey;
} }

View File

@ -12,16 +12,16 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import java.util.List;
import org.whispersystems.textsecuregcm.util.OptionalBase64ByteArrayDeserializer; import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import java.util.List; import org.signal.libsignal.protocol.IdentityKey;
import java.util.Optional; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.OptionalIdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.ValidPreKey;
import org.whispersystems.textsecuregcm.util.ValidPreKey.PreKeyType;
public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ID of an existing verification session as it appears in a verification session The ID of an existing verification session as it appears in a verification session
@ -57,16 +57,18 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
provided, an account will be created "atomically," and all other properties needed for provided, an account will be created "atomically," and all other properties needed for
atomic account creation must also be present. atomic account creation must also be present.
""") """)
@JsonDeserialize(using = OptionalBase64ByteArrayDeserializer.class) @JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class)
Optional<byte[]> aciIdentityKey, @JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class)
Optional<IdentityKey> aciIdentityKey,
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The PNI-associated identity key for the account, encoded as a base64 string. If The PNI-associated identity key for the account, encoded as a base64 string. If
provided, an account will be created "atomically," and all other properties needed for provided, an account will be created "atomically," and all other properties needed for
atomic account creation must also be present. atomic account creation must also be present.
""") """)
@JsonDeserialize(using = OptionalBase64ByteArrayDeserializer.class) @JsonSerialize(using = OptionalIdentityKeyAdapter.Serializer.class)
Optional<byte[]> pniIdentityKey, @JsonDeserialize(using = OptionalIdentityKeyAdapter.Deserializer.class)
Optional<IdentityKey> pniIdentityKey,
@JsonUnwrapped @JsonUnwrapped
@JsonProperty(access = JsonProperty.Access.READ_ONLY) @JsonProperty(access = JsonProperty.Access.READ_ONLY)
@ -78,8 +80,8 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
@JsonProperty("recoveryPassword") byte[] recoveryPassword, @JsonProperty("recoveryPassword") byte[] recoveryPassword,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("aciIdentityKey") Optional<byte[]> aciIdentityKey, @JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey,
@JsonProperty("pniIdentityKey") Optional<byte[]> pniIdentityKey, @JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey, @JsonProperty("aciSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> aciSignedPreKey,
@JsonProperty("pniSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey, @JsonProperty("pniSignedPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.ECC) SignedPreKey> pniSignedPreKey,
@JsonProperty("aciPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> aciPqLastResortPreKey, @JsonProperty("aciPqLastResortPreKey") Optional<@Valid @ValidPreKey(type=PreKeyType.KYBER) SignedPreKey> aciPqLastResortPreKey,
@ -103,7 +105,7 @@ public record RegistrationRequest(@Schema(requiredMode = Schema.RequiredMode.NOT
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private static boolean validatePreKeySignature(final Optional<byte[]> maybeIdentityKey, private static boolean validatePreKeySignature(final Optional<IdentityKey> maybeIdentityKey,
final Optional<SignedPreKey> maybeSignedPreKey) { final Optional<SignedPreKey> maybeSignedPreKey) {
return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey return maybeSignedPreKey.map(signedPreKey -> maybeIdentityKey

View File

@ -18,14 +18,15 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.function.Predicate; import java.util.function.Predicate;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
public class Account { public class Account {
@ -66,14 +67,14 @@ public class Account {
private List<Device> devices = new ArrayList<>(); private List<Device> devices = new ArrayList<>();
@JsonProperty @JsonProperty
@JsonSerialize(using = ByteArrayAdapter.Serializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
private byte[] identityKey; private IdentityKey identityKey;
@JsonProperty("pniIdentityKey") @JsonProperty("pniIdentityKey")
@JsonSerialize(using = ByteArrayAdapter.Serializing.class) @JsonSerialize(using = IdentityKeyAdapter.Serializer.class)
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
private byte[] phoneNumberIdentityKey; private IdentityKey phoneNumberIdentityKey;
@JsonProperty("cpv") @JsonProperty("cpv")
private String currentProfileVersion; private String currentProfileVersion;
@ -327,23 +328,23 @@ public class Account {
this.canonicallyDiscoverable = canonicallyDiscoverable; this.canonicallyDiscoverable = canonicallyDiscoverable;
} }
public void setIdentityKey(byte[] identityKey) { public void setIdentityKey(final IdentityKey identityKey) {
requireNotStale(); requireNotStale();
this.identityKey = identityKey; this.identityKey = identityKey;
} }
public byte[] getIdentityKey() { public IdentityKey getIdentityKey() {
requireNotStale(); requireNotStale();
return identityKey; return identityKey;
} }
public byte[] getPhoneNumberIdentityKey() { public IdentityKey getPhoneNumberIdentityKey() {
return phoneNumberIdentityKey; return phoneNumberIdentityKey;
} }
public void setPhoneNumberIdentityKey(final byte[] phoneNumberIdentityKey) { public void setPhoneNumberIdentityKey(final IdentityKey phoneNumberIdentityKey) {
this.phoneNumberIdentityKey = phoneNumberIdentityKey; this.phoneNumberIdentityKey = phoneNumberIdentityKey;
} }

View File

@ -30,10 +30,6 @@ import java.util.function.Predicate;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
@ -81,8 +77,6 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset")); private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete")); private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
private static final String INVALID_IDENTITY_KEY_COUNTER_NAME = name(Accounts.class, "invalidIdentityKey");
private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed"; private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed";
private static final String TRANSACTION_CONFLICT = "TransactionConflict"; private static final String TRANSACTION_CONFLICT = "TransactionConflict";
@ -915,9 +909,6 @@ public class Accounts extends AbstractDynamoDbStore {
.map(AttributeValue::bool) .map(AttributeValue::bool)
.orElse(false)); .orElse(false));
checkIdentityKey(account.getUuid(), account.getIdentityKey(), "aci");
checkIdentityKey(account.getUuid(), account.getPhoneNumberIdentityKey(), "pni");
return account; return account;
} catch (final IOException e) { } catch (final IOException e) {
@ -925,19 +916,6 @@ public class Accounts extends AbstractDynamoDbStore {
} }
} }
private static void checkIdentityKey(final UUID accountIdentifier, @Nullable final byte[] identityKey, final String keyType) {
if (identityKey != null && identityKey.length > 0) {
try {
new IdentityKey(identityKey);
} catch (final InvalidKeyException e) {
if (identityKey.length != ECPublicKey.KEY_SIZE - 1) {
log.warn("Account {} has an invalid {} identity key; length = {}", accountIdentifier, keyType, identityKey.length);
Metrics.counter(INVALID_IDENTITY_KEY_COUNTER_NAME, "type", keyType).increment();
}
}
}
}
private static boolean conditionalCheckFailed(final CancellationReason reason) { private static boolean conditionalCheckFailed(final CancellationReason reason) {
return CONDITIONAL_CHECK_FAILED.equals(reason.code()); return CONDITIONAL_CHECK_FAILED.equals(reason.code());
} }

View File

@ -27,6 +27,7 @@ import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt; import java.util.OptionalInt;
import java.util.UUID; import java.util.UUID;
@ -38,6 +39,7 @@ import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
@ -255,7 +257,7 @@ public class AccountsManager {
public Account changeNumber(final Account account, public Account changeNumber(final Account account,
final String targetNumber, final String targetNumber,
@Nullable final byte[] pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys, @Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { @Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
@ -347,7 +349,7 @@ public class AccountsManager {
} }
public Account updatePniKeys(final Account account, public Account updatePniKeys(final Account account,
final byte[] pniIdentityKey, final IdentityKey pniIdentityKey,
final Map<Long, SignedPreKey> pniSignedPreKeys, final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys, @Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException { final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
@ -366,7 +368,7 @@ public class AccountsManager {
} }
private boolean setPniKeys(final Account account, private boolean setPniKeys(final Account account,
@Nullable final byte[] pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) { @Nullable final Map<Long, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
@ -375,7 +377,7 @@ public class AccountsManager {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
} }
boolean changed = !Arrays.equals(pniIdentityKey, account.getPhoneNumberIdentityKey()); boolean changed = !Objects.equals(pniIdentityKey, account.getPhoneNumberIdentityKey());
for (Device device : account.getDevices()) { for (Device device : account.getDevices()) {
if (!device.isEnabled()) { if (!device.isEnabled()) {

View File

@ -6,7 +6,14 @@ package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
import org.signal.libsignal.protocol.IdentityKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.AccountController;
@ -19,12 +26,6 @@ import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
public class ChangeNumberManager { public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(AccountController.class); private static final Logger logger = LoggerFactory.getLogger(AccountController.class);
@ -39,7 +40,7 @@ public class ChangeNumberManager {
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@Nullable final byte[] pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys, @Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys, @Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages, @Nullable final List<IncomingMessage> deviceMessages,
@ -79,7 +80,7 @@ public class ChangeNumberManager {
} }
public Account updatePniKeys(final Account account, public Account updatePniKeys(final Account account,
final byte[] pniIdentityKey, final IdentityKey pniIdentityKey,
final Map<Long, SignedPreKey> deviceSignedPreKeys, final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys, @Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages, final List<IncomingMessage> deviceMessages,

View File

@ -0,0 +1,64 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
public class IdentityKeyAdapter {
private static final Counter IDENTITY_KEY_WITHOUT_VERSION_BYTE_COUNTER =
Metrics.counter(MetricsUtil.name(IdentityKeyAdapter.class), "identityKeyWithoutVersionByte");
public static class Serializer extends JsonSerializer<IdentityKey> {
@Override
public void serialize(final IdentityKey identityKey,
final JsonGenerator jsonGenerator,
final SerializerProvider serializers) throws IOException {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(identityKey.serialize()));
}
}
public static class Deserializer extends JsonDeserializer<IdentityKey> {
@Override
public IdentityKey deserialize(final JsonParser parser, final DeserializationContext context) throws IOException {
final byte[] identityKeyBytes;
try {
identityKeyBytes = Base64.getDecoder().decode(parser.getValueAsString());
} catch (final IllegalArgumentException e) {
throw new JsonParseException(parser, "Could not parse identity key as a base64-encoded value", e);
}
try {
return new IdentityKey(identityKeyBytes);
} catch (final InvalidKeyException e) {
if (identityKeyBytes.length == ECPublicKey.KEY_SIZE - 1) {
IDENTITY_KEY_WITHOUT_VERSION_BYTE_COUNTER.increment();
return new IdentityKey(ECPublicKey.fromPublicKeyBytes(identityKeyBytes));
}
throw new JsonParseException(parser, "Could not interpret identity key bytes as an EC public key", e);
}
}
}
}

View File

@ -1,22 +0,0 @@
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import java.io.IOException;
import java.util.Base64;
import java.util.Optional;
public class OptionalBase64ByteArrayDeserializer extends JsonDeserializer<Optional<byte[]>> {
@Override
public Optional<byte[]> deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) throws IOException {
return Optional.of(Base64.getDecoder().decode(jsonParser.getValueAsString()));
}
@Override
public Optional<byte[]> getNullValue(DeserializationContext ctxt) {
return Optional.empty();
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import java.util.Base64;
import java.util.Optional;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException;
public class OptionalIdentityKeyAdapter {
public static class Serializer extends JsonSerializer<Optional<IdentityKey>> {
@Override
public void serialize(final Optional<IdentityKey> maybePublicKey,
final JsonGenerator jsonGenerator,
final SerializerProvider serializers) throws IOException {
if (maybePublicKey.isPresent()) {
jsonGenerator.writeString(Base64.getEncoder().encodeToString(maybePublicKey.get().serialize()));
} else {
jsonGenerator.writeNull();
}
}
}
public static class Deserializer extends JsonDeserializer<Optional<IdentityKey>> {
@Override
public Optional<IdentityKey> deserialize(final JsonParser jsonParser, final DeserializationContext deserializationContext) throws IOException {
try {
return Optional.of(new IdentityKey(Base64.getDecoder().decode(jsonParser.getValueAsString())));
} catch (final InvalidKeyException e) {
throw new IOException(e);
}
}
@Override
public Optional<IdentityKey> getNullValue(DeserializationContext ctxt) {
return Optional.empty();
}
}
}

View File

@ -14,7 +14,9 @@ import java.security.InvalidKeyException;
import java.util.Base64; import java.util.Base64;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -22,7 +24,7 @@ class CertificateGeneratorTest {
private static final String SIGNING_CERTIFICATE = "CiUIDBIhBbTz4h1My+tt+vw+TVscgUe/DeHS0W02tPWAWbTO2xc3EkD+go4bJnU0AcnFfbOLKoiBfCzouZtDYMOVi69rE7r4U9cXREEqOkUmU2WJBjykAxWPCcSTmVTYHDw7hkSp/puG"; private static final String SIGNING_CERTIFICATE = "CiUIDBIhBbTz4h1My+tt+vw+TVscgUe/DeHS0W02tPWAWbTO2xc3EkD+go4bJnU0AcnFfbOLKoiBfCzouZtDYMOVi69rE7r4U9cXREEqOkUmU2WJBjykAxWPCcSTmVTYHDw7hkSp/puG";
private static final String SIGNING_KEY = "ABOxG29xrfq4E7IrW11Eg7+HBbtba9iiS0500YoBjn4="; private static final String SIGNING_KEY = "ABOxG29xrfq4E7IrW11Eg7+HBbtba9iiS0500YoBjn4=";
private static final byte[] IDENTITY_KEY = Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"); private static final IdentityKey IDENTITY_KEY = new IdentityKey(ECPublicKey.fromPublicKeyBytes(Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo")));
@Test @Test
void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException { void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException {

View File

@ -70,6 +70,7 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.usernames.BaseUsernameException; import org.signal.libsignal.usernames.BaseUsernameException;
@ -339,7 +340,7 @@ class AccountControllerTest {
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> { when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0); final Account account = invocation.getArgument(0);
final String number = invocation.getArgument(1); final String number = invocation.getArgument(1);
final byte[] pniIdentityKey = invocation.getArgument(2); final IdentityKey pniIdentityKey = invocation.getArgument(2);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
final UUID pni = number.equals(account.getNumber()) ? account.getPhoneNumberIdentifier() : UUID.randomUUID(); final UUID pni = number.equals(account.getNumber()) ? account.getPhoneNumberIdentifier() : UUID.randomUUID();
@ -362,7 +363,7 @@ class AccountControllerTest {
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> { when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0); final Account account = invocation.getArgument(0);
final byte[] pniIdentityKey = invocation.getArgument(1); final IdentityKey pniIdentityKey = invocation.getArgument(1);
final String number = account.getNumber(); final String number = account.getNumber();
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
@ -1646,7 +1647,7 @@ class AccountControllerTest {
final String number = "+18005559876"; final String number = "+18005559876";
final String code = "987654"; final String code = "987654";
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8);
Device device2 = mock(Device.class); Device device2 = mock(Device.class);
@ -1700,7 +1701,7 @@ class AccountControllerTest {
void testChangePhoneNumberSameNumberChangePrekeys() throws Exception { void testChangePhoneNumberSameNumberChangePrekeys() throws Exception {
final String code = "987654"; final String code = "987654";
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8);
Device device2 = mock(Device.class); Device device2 = mock(Device.class);

View File

@ -61,6 +61,7 @@ import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -97,6 +98,8 @@ class AccountControllerV2Test {
private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds(); private static final long SESSION_EXPIRATION_SECONDS = Duration.ofMinutes(10).toSeconds();
private static final IdentityKey IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format( private static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
@ -140,7 +143,7 @@ class AccountControllerV2Test {
(Answer<Account>) invocation -> { (Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0); final Account account = invocation.getArgument(0);
final String number = invocation.getArgument(1); final String number = invocation.getArgument(1);
final byte[] pniIdentityKey = invocation.getArgument(2); final IdentityKey pniIdentityKey = invocation.getArgument(2);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
final List<Device> devices = account.getDevices(); final List<Device> devices = account.getDevices();
@ -180,7 +183,7 @@ class AccountControllerV2Test {
.header(HttpHeaders.AUTHORIZATION, .header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123".getBytes(StandardCharsets.UTF_8), new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), null, Collections.emptyMap()), Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -203,7 +206,7 @@ class AccountControllerV2Test {
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null,
"pni-identity-key".getBytes(StandardCharsets.UTF_8), new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), null, Collections.emptyMap()), Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -407,12 +410,12 @@ class AccountControllerV2Test {
"recoveryPassword": "%s", "recoveryPassword": "%s",
"number": "%s", "number": "%s",
"reglock": "1234", "reglock": "1234",
"pniIdentityKey": "5678", "pniIdentityKey": "%s",
"deviceMessages": [], "deviceMessages": [],
"devicePniSignedPrekeys": {}, "devicePniSignedPrekeys": {},
"pniRegistrationIds": {} "pniRegistrationIds": {}
} }
""", encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber); """, encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()));
} }
/** /**
@ -463,7 +466,7 @@ class AccountControllerV2Test {
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer( when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> { (Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0); final Account account = invocation.getArgument(0);
final byte[] pniIdentityKey = invocation.getArgument(1); final IdentityKey pniIdentityKey = invocation.getArgument(1);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
final UUID pni = account.getPhoneNumberIdentifier(); final UUID pni = account.getPhoneNumberIdentifier();
@ -498,7 +501,7 @@ class AccountControllerV2Test {
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(requestJson()), AccountIdentityResponse.class); .put(Entity.json(requestJson()), AccountIdentityResponse.class);
verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key".getBytes(StandardCharsets.UTF_8)), any(), any(), any(), any()); verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq(IDENTITY_KEY), any(), any(), any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number());
@ -562,7 +565,7 @@ class AccountControllerV2Test {
"devicePniSignedPqPrekeys": {}, "devicePniSignedPqPrekeys": {},
"pniRegistrationIds": {} "pniRegistrationIds": {}
} }
""", Base64.getEncoder().encodeToString("pni-identity-key".getBytes(StandardCharsets.UTF_8))); """, Base64.getEncoder().encodeToString(IDENTITY_KEY.serialize()));
} }
/** /**
@ -798,8 +801,8 @@ class AccountControllerV2Test {
account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess); account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess);
account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber); account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber);
account.setBadges(Clock.systemUTC(), new ArrayList<>(badges)); account.setBadges(Clock.systemUTC(), new ArrayList<>(badges));
account.setIdentityKey(aciIdentityKeyPair.getPublicKey().serialize()); account.setIdentityKey(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
account.setPhoneNumberIdentityKey(pniIdentityKeyPair.getPublicKey().serialize()); account.setPhoneNumberIdentityKey(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
assert !devices.isEmpty(); assert !devices.isEmpty();

View File

@ -33,12 +33,12 @@ import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HexFormat; import java.util.HexFormat;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -61,6 +61,8 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.ServerPublicParams; import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
@ -124,10 +126,10 @@ class ProfileControllerTest {
private static final ServerZkProfileOperations zkProfileOperations = mock(ServerZkProfileOperations.class); private static final ServerZkProfileOperations zkProfileOperations = mock(ServerZkProfileOperations.class);
private static final byte[] UNIDENTIFIED_ACCESS_KEY = "test-uak".getBytes(StandardCharsets.UTF_8); private static final byte[] UNIDENTIFIED_ACCESS_KEY = "test-uak".getBytes(StandardCharsets.UTF_8);
private static final byte[] ACCOUNT_IDENTITY_KEY = "barz".getBytes(StandardCharsets.UTF_8); private static final IdentityKey ACCOUNT_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] ACCOUNT_PHONE_NUMBER_IDENTITY_KEY = "bazz".getBytes(StandardCharsets.UTF_8); private static final IdentityKey ACCOUNT_PHONE_NUMBER_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] ACCOUNT_TWO_IDENTITY_KEY = "bar".getBytes(StandardCharsets.UTF_8); private static final IdentityKey ACCOUNT_TWO_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final byte[] ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY = "baz".getBytes(StandardCharsets.UTF_8); private static final IdentityKey ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY = new IdentityKey(Curve.generateKeyPair().getPublicKey());
private static final String BASE_64_URL_USERNAME_HASH = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE"; private static final String BASE_64_URL_USERNAME_HASH = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE";
private static final byte[] USERNAME_HASH = Base64.getUrlDecoder().decode(BASE_64_URL_USERNAME_HASH); private static final byte[] USERNAME_HASH = Base64.getUrlDecoder().decode(BASE_64_URL_USERNAME_HASH);
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -1170,26 +1172,31 @@ class ProfileControllerTest {
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> { final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> {
if (AuthHelper.VALID_UUID.equals(element.aci())) { if (AuthHelper.VALID_UUID.equals(element.aci())) {
return Arrays.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); return Objects.equals(ACCOUNT_IDENTITY_KEY, element.identityKey());
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) {
return Arrays.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); return Objects.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey());
} else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) { } else if (AuthHelper.VALID_UUID_TWO.equals(element.uuid())) {
return Arrays.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey()); return Objects.equals(ACCOUNT_TWO_IDENTITY_KEY, element.identityKey());
} else { } else {
return false; return false;
} }
}, "is an expected UUID with the correct identity key"); }, "is an expected UUID with the correct identity key");
final IdentityKey validAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final IdentityKey secondValidPniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final IdentityKey secondValidAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final IdentityKey invalidAciIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.json(new BatchIdentityCheckRequest(List.of( .post(Entity.json(new BatchIdentityCheckRequest(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null,
convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO,
convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO, new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_UUID_TWO,
convertKeyToFingerprint("another2".getBytes(StandardCharsets.UTF_8))), convertKeyToFingerprint(secondValidAciIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null,
convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8))) convertKeyToFingerprint(invalidAciIdentityKey))
))))) { ))))) {
assertThat(response).isNotNull(); assertThat(response).isNotNull();
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
@ -1202,13 +1209,13 @@ class ProfileControllerTest {
} }
final List<BatchIdentityCheckRequest.Element> largeElementList = new ArrayList<>(List.of( final List<BatchIdentityCheckRequest.Element> largeElementList = new ArrayList<>(List.of(
new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), new BatchIdentityCheckRequest.Element(AuthHelper.VALID_UUID, null, convertKeyToFingerprint(validAciIdentityKey)),
new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), new BatchIdentityCheckRequest.Element(null, AuthHelper.VALID_PNI_TWO, convertKeyToFingerprint(secondValidPniIdentityKey)),
new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8))))); new BatchIdentityCheckRequest.Element(AuthHelper.INVALID_UUID, null, convertKeyToFingerprint(invalidAciIdentityKey))));
for (int i = 0; i < 900; i++) { for (int i = 0; i < 900; i++) {
largeElementList.add( largeElementList.add(
new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint("abcd".getBytes(StandardCharsets.UTF_8)))); new BatchIdentityCheckRequest.Element(UUID.randomUUID(), null, convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
} }
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
@ -1228,9 +1235,9 @@ class ProfileControllerTest {
final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> { final Condition<BatchIdentityCheckResponse.Element> isAnExpectedUuid = new Condition<>(element -> {
if (AuthHelper.VALID_UUID.equals(element.aci())) { if (AuthHelper.VALID_UUID.equals(element.aci())) {
return Arrays.equals(ACCOUNT_IDENTITY_KEY, element.identityKey()); return ACCOUNT_IDENTITY_KEY.equals(element.identityKey());
} else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) { } else if (AuthHelper.VALID_PNI_TWO.equals(element.uuid())) {
return Arrays.equals(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY, element.identityKey()); return ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY.equals(element.identityKey());
} else { } else {
return false; return false;
} }
@ -1245,9 +1252,9 @@ class ProfileControllerTest {
{ "aci": "%s", "fingerprint": "%s" } { "aci": "%s", "fingerprint": "%s" }
] ]
} }
""", AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))), """, AuthHelper.VALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint("another1".getBytes(StandardCharsets.UTF_8))), AuthHelper.VALID_PNI_TWO, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))),
AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint("456".getBytes(StandardCharsets.UTF_8)))); AuthHelper.INVALID_UUID, Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey()))));
try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request() try (final Response response = resources.getJerseyTest().target("/v1/profile/identity_check/batch").request()
.post(Entity.entity(json, "application/json"))) { .post(Entity.entity(json, "application/json"))) {
@ -1313,15 +1320,15 @@ class ProfileControllerTest {
] ]
} }
""", AuthHelper.VALID_UUID, AuthHelper.VALID_PNI, """, AuthHelper.VALID_UUID, AuthHelper.VALID_PNI,
Base64.getEncoder().encodeToString(convertKeyToFingerprint("else1234".getBytes(StandardCharsets.UTF_8))))) Base64.getEncoder().encodeToString(convertKeyToFingerprint(new IdentityKey(Curve.generateKeyPair().getPublicKey())))))
); );
} }
private static byte[] convertKeyToFingerprint(byte[] key) { private static byte[] convertKeyToFingerprint(final IdentityKey publicKey) {
try { try {
return Util.truncate(MessageDigest.getInstance("SHA-256").digest(key), 4); return Util.truncate(MessageDigest.getInstance("SHA-256").digest(publicKey.serialize()), 4);
} catch (NoSuchAlgorithmException e) { } catch (final NoSuchAlgorithmException e) {
throw new AssertionError(e); throw new AssertionError("All Java implementations must support SHA-256 MessageDigest algorithm", e);
} }
} }
} }

View File

@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
@ -415,8 +416,8 @@ class RegistrationControllerTest {
} }
static Stream<Arguments> atomicAccountCreationConflictingChannel() { static Stream<Arguments> atomicAccountCreationConflictingChannel() {
final Optional<byte[]> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<byte[]> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<SignedPreKey> aciPqLastResortPreKey;
@ -425,8 +426,8 @@ class RegistrationControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
@ -504,8 +505,8 @@ class RegistrationControllerTest {
} }
static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() { static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
final Optional<byte[]> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<byte[]> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<SignedPreKey> aciPqLastResortPreKey;
@ -514,8 +515,8 @@ class RegistrationControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
@ -617,8 +618,8 @@ class RegistrationControllerTest {
@MethodSource @MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final byte[] expectedAciIdentityKey, final IdentityKey expectedAciIdentityKey,
final byte[] expectedPniIdentityKey, final IdentityKey expectedPniIdentityKey,
final SignedPreKey expectedAciSignedPreKey, final SignedPreKey expectedAciSignedPreKey,
final SignedPreKey expectedPniSignedPreKey, final SignedPreKey expectedPniSignedPreKey,
final SignedPreKey expectedAciPqLastResortPreKey, final SignedPreKey expectedAciPqLastResortPreKey,
@ -683,8 +684,8 @@ class RegistrationControllerTest {
} }
private static Stream<Arguments> atomicAccountCreationSuccess() { private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<byte[]> aciIdentityKey; final Optional<IdentityKey> aciIdentityKey;
final Optional<byte[]> pniIdentityKey; final Optional<IdentityKey> pniIdentityKey;
final Optional<SignedPreKey> aciSignedPreKey; final Optional<SignedPreKey> aciSignedPreKey;
final Optional<SignedPreKey> pniSignedPreKey; final Optional<SignedPreKey> pniSignedPreKey;
final Optional<SignedPreKey> aciPqLastResortPreKey; final Optional<SignedPreKey> aciPqLastResortPreKey;
@ -693,8 +694,8 @@ class RegistrationControllerTest {
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciIdentityKey = Optional.of(aciIdentityKeyPair.getPublicKey().serialize()); aciIdentityKey = Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
pniIdentityKey = Optional.of(pniIdentityKeyPair.getPublicKey().serialize()); pniIdentityKey = Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));

View File

@ -5,7 +5,9 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -21,6 +23,7 @@ import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -155,7 +158,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
@ -172,7 +175,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
assertArrayEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey()); assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey());
assertEquals(OptionalInt.of(rotatedPniRegistrationId), assertEquals(OptionalInt.of(rotatedPniRegistrationId),
updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentityRegistrationId()); updatedAccount.getMasterDevice().orElseThrow().getPhoneNumberIdentityRegistrationId());

View File

@ -20,7 +20,6 @@ import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
@ -38,6 +37,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
@ -147,7 +148,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
final boolean discoverableByPhoneNumber = false; final boolean discoverableByPhoneNumber = false;
final String currentProfileVersion = "cpv"; final String currentProfileVersion = "cpv";
final byte[] identityKey = "ikey".getBytes(StandardCharsets.UTF_8); final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final byte[] unidentifiedAccessKey = new byte[]{1}; final byte[] unidentifiedAccessKey = new byte[]{1};
final String pin = "1234"; final String pin = "1234";
final String registrationLock = "reglock"; final String registrationLock = "reglock";
@ -189,12 +190,12 @@ class AccountsManagerConcurrentModificationIntegrationTest {
return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class); return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class);
} }
private void verifyAccount(final String name, final Account account, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final byte[] identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAccess, final long lastSeen) { private void verifyAccount(final String name, final Account account, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final IdentityKey identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAccess, final long lastSeen) {
assertAll(name, assertAll(name,
() -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()), () -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()),
() -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().orElseThrow()), () -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().orElseThrow()),
() -> assertArrayEquals(identityKey, account.getIdentityKey()), () -> assertEquals(identityKey, account.getIdentityKey()),
() -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().orElseThrow()), () -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().orElseThrow()),
() -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock)), () -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock)),
() -> assertEquals(unrestrictedUnidentifiedAccess, account.isUnrestrictedUnidentifiedAccess()) () -> assertEquals(unrestrictedUnidentifiedAccess, account.isUnrestrictedUnidentifiedAccess())

View File

@ -5,7 +5,12 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
@ -24,8 +29,6 @@ import static org.mockito.Mockito.when;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
@ -46,6 +49,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -473,10 +477,12 @@ class AccountsManagerTest {
.doAnswer(ACCOUNT_UPDATE_ANSWER) .doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any()); .when(accounts).update(any());
account = accountsManager.update(account, a -> a.setIdentityKey("identity-key".getBytes(StandardCharsets.UTF_8))); final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
account = accountsManager.update(account, a -> a.setIdentityKey(identityKey));
assertEquals(1, account.getVersion()); assertEquals(1, account.getVersion());
assertArrayEquals("identity-key".getBytes(StandardCharsets.UTF_8), account.getIdentityKey()); assertEquals(identityKey, account.getIdentityKey());
verify(accounts, times(1)).getByAccountIdentifier(uuid); verify(accounts, times(1)).getByAccountIdentifier(uuid);
verify(accounts, times(2)).update(any()); verify(accounts, times(2)).update(any());
@ -669,7 +675,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber( () -> accountsManager.changeNumber(
account, number, "new-identity-key".getBytes(StandardCharsets.UTF_8), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any()); verify(accounts, never()).update(any());
@ -728,7 +734,7 @@ class AccountsManagerTest {
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
final Account updatedAccount = accountsManager.changeNumber( final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, newSignedPqKeys, newRegistrationIds); account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
assertEquals(targetNumber, updatedAccount.getNumber()); assertEquals(targetNumber, updatedAccount.getNumber());
@ -771,7 +777,9 @@ class AccountsManagerTest {
UUID oldPni = account.getPhoneNumberIdentifier(); UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final Account updatedAccount = accountsManager.updatePniKeys(account, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, null, newRegistrationIds); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount = accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, null, newRegistrationIds);
// non-PNI stuff should not change // non-PNI stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid()); assertEquals(oldUuid, updatedAccount.getUuid());
@ -783,7 +791,7 @@ class AccountsManagerTest {
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI stuff should // PNI stuff should
assertArrayEquals("new-pni-identity-key".getBytes(StandardCharsets.UTF_8), updatedAccount.getPhoneNumberIdentityKey()); assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey());
assertEquals(newSignedKeys, assertEquals(newSignedKeys,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey)));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,
@ -817,8 +825,10 @@ class AccountsManagerTest {
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount = final Account updatedAccount =
accountsManager.updatePniKeys(account, "new-pni-identity-key".getBytes(StandardCharsets.UTF_8), newSignedKeys, newSignedPqKeys, newRegistrationIds); accountsManager.updatePniKeys(account, pniIdentityKey, newSignedKeys, newSignedPqKeys, newRegistrationIds);
// non-PNI-keys stuff should not change // non-PNI-keys stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid()); assertEquals(oldUuid, updatedAccount.getUuid());
@ -830,7 +840,7 @@ class AccountsManagerTest {
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should // PNI keys should
assertArrayEquals("new-pni-identity-key".getBytes(StandardCharsets.UTF_8), updatedAccount.getPhoneNumberIdentityKey()); assertEquals(pniIdentityKey, updatedAccount.getPhoneNumberIdentityKey());
assertEquals(newSignedKeys, assertEquals(newSignedKeys,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey)));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,

View File

@ -14,7 +14,6 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
@ -27,6 +26,8 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
@ -106,7 +107,7 @@ public class ChangeNumberManagerTest {
Account account = mock(Account.class); Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
var prekeys = Map.of(1L, new SignedPreKey()); var prekeys = Map.of(1L, new SignedPreKey());
final byte[] pniIdentityKey = "pni-identity-key".getBytes(StandardCharsets.UTF_8); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
@ -132,7 +133,7 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final byte[] pniIdentityKey = "pni-identity-key".getBytes(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
@ -175,7 +176,7 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final byte[] pniIdentityKey = "pni-identity-key".getBytes(StandardCharsets.UTF_8); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
@ -217,7 +218,7 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final byte[] pniIdentityKey = "pni-identity-key".getBytes(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
@ -257,7 +258,7 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final byte[] pniIdentityKey = "pni-identity-key".getBytes(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
@ -296,7 +297,7 @@ public class ChangeNumberManagerTest {
when(account.getDevice(2L)).thenReturn(Optional.of(d2)); when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2)); when(account.getDevices()).thenReturn(List.of(d2));
final byte[] pniIdentityKey = "pni-identity-key".getBytes(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
@ -347,7 +348,7 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key".getBytes(StandardCharsets.UTF_8), preKeys, null, messages, registrationIds)); () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
} }
@Test @Test
@ -377,7 +378,7 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, "pni-identity-key".getBytes(StandardCharsets.UTF_8), preKeys, null, messages, registrationIds)); () -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
} }
@Test @Test
@ -406,6 +407,6 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key".getBytes(StandardCharsets.UTF_8), null, null, messages, registrationIds)); () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), null, null, messages, registrationIds));
} }
} }

View File

@ -113,7 +113,7 @@ class CertificateControllerTest {
assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderDevice(), 1L);
assertTrue(certificate.hasSenderUuid()); assertTrue(certificate.hasSenderUuid());
assertEquals(AuthHelper.VALID_UUID.toString(), certificate.getSenderUuid()); assertEquals(AuthHelper.VALID_UUID.toString(), certificate.getSenderUuid());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize());
} }
@Test @Test
@ -141,7 +141,7 @@ class CertificateControllerTest {
assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER); assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER);
assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString()); assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize());
} }
@Test @Test
@ -170,7 +170,7 @@ class CertificateControllerTest {
assertTrue(StringUtils.isBlank(certificate.getSender())); assertTrue(StringUtils.isBlank(certificate.getSender()));
assertEquals(certificate.getSenderDevice(), 1L); assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString()); assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY); assertArrayEquals(certificate.getIdentityKey().toByteArray(), AuthHelper.VALID_IDENTITY.serialize());
} }
@Test @Test

View File

@ -42,6 +42,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -277,8 +278,8 @@ class DeviceControllerTest {
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(aciIdentityKeyPair.getPublicKey().serialize()); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKeyPair.getPublicKey().serialize()); when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901", final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null), new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
@ -363,8 +364,8 @@ class DeviceControllerTest {
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(aciIdentityKeyPair.getPublicKey().serialize()); when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKeyPair.getPublicKey().serialize()); when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest("5678901", final LinkDeviceRequest request = new LinkDeviceRequest("5678901",
new AccountAttributes(fetchesMessages, 1234, null, null, true, null), new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
@ -392,8 +393,8 @@ class DeviceControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicMissingProperty(final byte[] aciIdentityKey, void linkDeviceAtomicMissingProperty(final IdentityKey aciIdentityKey,
final byte[] pniIdentityKey, final IdentityKey pniIdentityKey,
final Optional<SignedPreKey> aciSignedPreKey, final Optional<SignedPreKey> aciSignedPreKey,
final Optional<SignedPreKey> pniSignedPreKey, final Optional<SignedPreKey> pniSignedPreKey,
final Optional<SignedPreKey> aciPqLastResortPreKey, final Optional<SignedPreKey> aciPqLastResortPreKey,
@ -439,8 +440,8 @@ class DeviceControllerTest {
final Optional<SignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); final Optional<SignedPreKey> aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
final Optional<SignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); final Optional<SignedPreKey> pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final byte[] aciIdentityKey = aciIdentityKeyPair.getPublicKey().serialize(); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of( return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, Optional.empty(), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),
@ -452,8 +453,8 @@ class DeviceControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void linkDeviceAtomicInvalidSignature(final byte[] aciIdentityKey, void linkDeviceAtomicInvalidSignature(final IdentityKey aciIdentityKey,
final byte[] pniIdentityKey, final IdentityKey pniIdentityKey,
final SignedPreKey aciSignedPreKey, final SignedPreKey aciSignedPreKey,
final SignedPreKey pniSignedPreKey, final SignedPreKey pniSignedPreKey,
final SignedPreKey aciPqLastResortPreKey, final SignedPreKey aciPqLastResortPreKey,
@ -499,8 +500,8 @@ class DeviceControllerTest {
final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); final SignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); final SignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
final byte[] aciIdentityKey = aciIdentityKeyPair.getPublicKey().serialize(); final IdentityKey aciIdentityKey = new IdentityKey(aciIdentityKeyPair.getPublicKey());
final byte[] pniIdentityKey = pniIdentityKeyPair.getPublicKey().serialize(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
return Stream.of( return Stream.of(
Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), Arguments.of(aciIdentityKey, pniIdentityKey, signedPreKeyWithBadSignature(aciSignedPreKey), pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey),

View File

@ -39,6 +39,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -80,10 +81,10 @@ class KeysControllerTest {
private static final int SAMPLE_PNI_REGISTRATION_ID = 1717; private static final int SAMPLE_PNI_REGISTRATION_ID = 1717;
private final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final byte[] IDENTITY_KEY = IDENTITY_KEY_PAIR.getPublicKey().serialize(); private final IdentityKey IDENTITY_KEY = new IdentityKey(IDENTITY_KEY_PAIR.getPublicKey());
private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final byte[] PNI_IDENTITY_KEY = PNI_IDENTITY_KEY_PAIR.getPublicKey().serialize(); private final IdentityKey PNI_IDENTITY_KEY = new IdentityKey(PNI_IDENTITY_KEY_PAIR.getPublicKey());
private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234); private final PreKey SAMPLE_KEY = KeysHelper.ecPreKey(1234);
private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667); private final PreKey SAMPLE_KEY2 = KeysHelper.ecPreKey(5667);
@ -658,7 +659,7 @@ class KeysControllerTest {
final PreKey preKey = KeysHelper.ecPreKey(31337); final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@ -688,7 +689,7 @@ class KeysControllerTest {
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@ -716,7 +717,7 @@ class KeysControllerTest {
@Test @Test
void putKeysStructurallyInvalidSignedECKey() { void putKeysStructurallyInvalidSignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair); final SignedPreKey wrongPreKey = KeysHelper.signedKEMPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null); final PreKeyState preKeyState = new PreKeyState(identityKey, wrongPreKey, null, null, null);
@ -729,11 +730,11 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
} }
@Test @Test
void putKeysStructurallyInvalidUnsignedECKey() { void putKeysStructurallyInvalidUnsignedECKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes()); final PreKey wrongPreKey = new PreKey(1, "cluck cluck i'm a parrot".getBytes());
final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null); final PreKeyState preKeyState = new PreKeyState(identityKey, null, List.of(wrongPreKey), null, null);
@ -746,11 +747,11 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
} }
@Test @Test
void putKeysStructurallyInvalidPQOneTimeKey() { void putKeysStructurallyInvalidPQOneTimeKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null); final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, List.of(wrongPreKey), null);
@ -763,11 +764,11 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
} }
@Test @Test
void putKeysStructurallyInvalidPQLastResortKey() { void putKeysStructurallyInvalidPQLastResortKey() {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair); final SignedPreKey wrongPreKey = KeysHelper.signedECPreKey(1, identityKeyPair);
final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey); final PreKeyState preKeyState = new PreKeyState(identityKey, null, null, null, wrongPreKey);
@ -780,13 +781,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
} }
@Test @Test
void putKeysByPhoneNumberIdentifierTestV2() { void putKeysByPhoneNumberIdentifierTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
@ -817,7 +818,7 @@ class KeysControllerTest {
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair); final SignedPreKey pqPreKey = KeysHelper.signedKEMPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair); final SignedPreKey pqLastResortPreKey = KeysHelper.signedKEMPreKey(31340, identityKeyPair);
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
@ -860,10 +861,10 @@ class KeysControllerTest {
@Test @Test
void disabledPutKeysTestV2() { void disabledPutKeysTestV2() {
final PreKey preKey = KeysHelper.ecPreKey(31337); final PreKey preKey = KeysHelper.ecPreKey(31337);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedECPreKey(31338, identityKeyPair);
final byte[] identityKey = identityKeyPair.getPublicKey().serialize(); final IdentityKey identityKey = new IdentityKey(identityKeyPair.getPublicKey());
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));

View File

@ -20,6 +20,8 @@ import java.util.Base64;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.UUID; import java.util.UUID;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
@ -63,7 +65,8 @@ public class AuthHelper {
public static final UUID UNDISCOVERABLE_UUID = UUID.randomUUID(); public static final UUID UNDISCOVERABLE_UUID = UUID.randomUUID();
public static final String UNDISCOVERABLE_PASSWORD = "IT'S A SECRET TO EVERYBODY."; public static final String UNDISCOVERABLE_PASSWORD = "IT'S A SECRET TO EVERYBODY.";
public static final byte[] VALID_IDENTITY = Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo"); public static final IdentityKey VALID_IDENTITY = new IdentityKey(ECPublicKey.fromPublicKeyBytes(
Base64.getDecoder().decode("BcxxDU9FGMda70E7+Uvm7pnQcEdXQ64aJCpPUeRSfcFo")));
public static AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); public static AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
public static Account VALID_ACCOUNT = mock(Account.class ); public static Account VALID_ACCOUNT = mock(Account.class );