diff --git a/service/config/sample.yml b/service/config/sample.yml index 5ad344115..32e03572a 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -53,8 +53,12 @@ dynamoDbTables: tableName: Example_IssuedReceipts expiration: P30D # Duration of time until rows expire generator: abcdefg12345678= # random base64-encoded binary sequence - keys: + ecKeys: tableName: Example_Keys + pqKeys: + tableName: Example_PQ_Keys + pqLastResortKeys: + tableName: Example_PQ_Last_Resort_Keys messages: tableName: Example_Messages expiration: P30D # Duration of time until rows expire diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 1bd99a891..91ead2faf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -341,7 +341,10 @@ public class WhisperServerService extends Application identityType) { - int count = keys.getCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); + int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); - return new PreKeyCount(count); + return new PreKeyCount(ecCount, pqCount); } @Timed @@ -88,9 +97,17 @@ public class KeysController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) @ChangesDeviceEnabledState + @Operation(summary = "Sets the identity key for the account or phone-number identity and/or prekeys for this device") public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth, - @NotNull @Valid final PreKeyState preKeys, + @RequestBody @NotNull @Valid final PreKeyState preKeys, + + @Parameter(allowEmptyValue=true) + @Schema( + allowableValues={"aci", "pni"}, + defaultValue="aci", + description="whether this operation applies to the account (aci) or phone-number (pni) identity") @QueryParam("identity") final Optional identityType, + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { Account account = disabledPermittedAuth.getAccount(); Device device = disabledPermittedAuth.getAuthenticatedDevice(); @@ -98,7 +115,8 @@ public class KeysController { final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType); - if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) { + if (preKeys.getSignedPreKey() != null && + !preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) { updateAccount = true; } @@ -121,13 +139,15 @@ public class KeysController { if (updateAccount) { account = accounts.update(account, a -> { - a.getDevice(device.getId()).ifPresent(d -> { - if (usePhoneNumberIdentity) { - d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey()); - } else { - d.setSignedPreKey(preKeys.getSignedPreKey()); - } - }); + if (preKeys.getSignedPreKey() != null) { + a.getDevice(device.getId()).ifPresent(d -> { + if (usePhoneNumberIdentity) { + d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey()); + } else { + d.setSignedPreKey(preKeys.getSignedPreKey()); + } + }); + } if (usePhoneNumberIdentity) { a.setPhoneNumberIdentityKey(preKeys.getIdentityKey()); @@ -137,17 +157,29 @@ public class KeysController { }); } - keys.store(getIdentifier(account, identityType), device.getId(), preKeys.getPreKeys()); + keys.store( + getIdentifier(account, identityType), device.getId(), + preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getPqLastResortPreKey()); } @Timed @GET @Path("/{identifier}/{device_id}") @Produces(MediaType.APPLICATION_JSON) + @Operation(summary = "Retrieves the public identity key and available device prekeys for a specified account or phone-number identity") public Response getDeviceKeys(@Auth Optional auth, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + + @Parameter(description="the account or phone-number identifier to retrieve keys for") @PathParam("identifier") UUID targetUuid, + + @Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices") @PathParam("device_id") String deviceId, + + @Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys") + @Schema(defaultValue="false") + @QueryParam("pq") boolean returnPqKey, + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { @@ -175,28 +207,30 @@ public class KeysController { final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid); - Map preKeysByDeviceId = getLocalKeys(target, deviceId, usePhoneNumberIdentity); - List responseItems = new LinkedList<>(); + List devices = parseDeviceId(deviceId, target); + List responseItems = new ArrayList<>(devices.size()); - for (Device device : target.getDevices()) { - if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) { - SignedPreKey signedPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); - PreKey preKey = preKeysByDeviceId.get(device.getId()); + for (Device device : devices) { + UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid; + SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); + PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null); + SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null; - if (signedPreKey != null || preKey != null) { - final int registrationId = usePhoneNumberIdentity ? - device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : - device.getRegistrationId(); + if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { + final int registrationId = usePhoneNumberIdentity ? + device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : + device.getRegistrationId(); - responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey)); - } + responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey)); } } final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); - if (responseItems.isEmpty()) return Response.status(404).build(); - else return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build(); + if (responseItems.isEmpty()) { + return Response.status(404).build(); + } + return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build(); } @Timed @@ -243,31 +277,15 @@ public class KeysController { account.getUuid(); } - private Map getLocalKeys(Account destination, String deviceIdSelector, final boolean usePhoneNumberIdentity) { - final Map preKeys; - - final UUID identifier = usePhoneNumberIdentity ? - destination.getPhoneNumberIdentifier() : - destination.getUuid(); - - if (deviceIdSelector.equals("*")) { - preKeys = new HashMap<>(); - - for (final Device device : destination.getDevices()) { - keys.take(identifier, device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey)); - } - } else { - try { - long deviceId = Long.parseLong(deviceIdSelector); - - preKeys = keys.take(identifier, deviceId) - .map(preKey -> Map.of(deviceId, preKey)) - .orElse(Collections.emptyMap()); - } catch (NumberFormatException e) { - throw new WebApplicationException(Response.status(422).build()); - } + private List parseDeviceId(String deviceId, Account account) { + if (deviceId.equals("*")) { + return account.getDevices().stream().filter(Device::isEnabled).toList(); + } + try { + long id = Long.parseLong(deviceId); + return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of()); + } catch (NumberFormatException e) { + throw new WebApplicationException(Response.status(422).build()); } - - return preKeys; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index 153c4e54a..0593289bb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -16,21 +18,57 @@ import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotNull; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; -public record ChangeNumberRequest(String sessionId, - @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, - @NotBlank String number, - @JsonProperty("reglock") @Nullable String registrationLock, - @NotBlank String pniIdentityKey, - @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, - @NotNull @Valid Map devicePniSignedPrekeys, - @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { +public record ChangeNumberRequest( + @Schema(description=""" + A session ID from registration service, if using session id to authenticate this request. + Must not be combined with `recoveryPassword`.""") + String sessionId, + + @Schema(description=""" + The recovery password for the new phone number, if using a recovery password to authenticate this request. + Must not be combined with `sessionId`.""") + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, + + @Schema(description="the new phone number for this account") + @NotBlank String number, + + @Schema(description="the registration lock password for the new phone number, if necessary") + @JsonProperty("reglock") @Nullable String registrationLock, + + @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") + @NotBlank String pniIdentityKey, + + @Schema(description=""" + A list of synchronization messages to send to companion devices to supply the private keys + associated with the new identity key and their new prekeys. + Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") + @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, + + @Schema(description=""" + A new signed elliptic-curve prekey for each enabled device on the account, including this one. + Each must be accompanied by a valid signature from the new identity key in this request.""") + @NotNull @Valid Map devicePniSignedPrekeys, + + @Schema(description=""" + A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. + May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. + If present, must contain one prekey per enabled device including this one. + Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. + Each must be accompanied by a valid signature from the new identity key in this request.""") + @Valid Map devicePniPqLastResortPrekeys, + + @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") + @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - if (devicePniSignedPrekeys == null) { - return true; + List spks = new ArrayList<>(); + if (devicePniSignedPrekeys != null) { + spks.addAll(devicePniSignedPrekeys.values()); } - return devicePniSignedPrekeys.values().parallelStream() - .allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); + if (devicePniPqLastResortPrekeys != null) { + spks.addAll(devicePniPqLastResortPrekeys.values()); + } + return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index 0a3efb99a..9ccf3bb9d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -6,27 +6,61 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.validation.constraints.AssertTrue; import javax.annotation.Nullable; +import javax.validation.Valid; import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; -public record ChangePhoneNumberRequest(@NotBlank String number, - @NotBlank String code, - @JsonProperty("reglock") @Nullable String registrationLock, - @Nullable String pniIdentityKey, - @Nullable List deviceMessages, - @Nullable Map devicePniSignedPrekeys, - @Nullable Map pniRegistrationIds) { +public record ChangePhoneNumberRequest( + @Schema(description="the new phone number for this account") + @NotBlank String number, + + @Schema(description="the registration verification code to authenticate this request") + @NotBlank String code, + + @Schema(description="the registration lock password for the new phone number, if necessary") + @JsonProperty("reglock") @Nullable String registrationLock, + + @Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number") + @Nullable String pniIdentityKey, + + @Schema(description=""" + A list of synchronization messages to send to companion devices to supply the private keys + associated with the new identity key and their new prekeys. + Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") + @Nullable List deviceMessages, + + @Schema(description=""" + A new signed elliptic-curve prekey for each enabled device on the account, including this one. + Each must be accompanied by a valid signature from the new identity key in this request.""") + @Nullable Map devicePniSignedPrekeys, + + @Schema(description=""" + A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. + May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. + If present, must contain one prekey per enabled device including this one. + Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. + Each must be accompanied by a valid signature from the new identity key in this request.""") + @Nullable @Valid Map devicePniPqLastResortPrekeys, + + @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") + @Nullable Map pniRegistrationIds) { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - if (devicePniSignedPrekeys == null) { - return true; + List spks = new ArrayList<>(); + if (devicePniSignedPrekeys != null) { + spks.addAll(devicePniSignedPrekeys.values()); } - return devicePniSignedPrekeys.values().parallelStream() - .allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); + if (devicePniPqLastResortPrekeys != null) { + spks.addAll(devicePniPqLastResortPrekeys.values()); + } + return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index 619143fe3..16438c68b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -17,29 +18,45 @@ import javax.validation.constraints.NotNull; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; public record PhoneNumberIdentityKeyDistributionRequest( - @NotBlank - @Schema(description="the new identity key for this account's phone-number identity") - String pniIdentityKey, + @NotBlank + @Schema(description="the new identity key for this account's phone-number identity") + String pniIdentityKey, - @NotNull - @Valid - @Schema(description="A message for each companion device to pass its new private keys") - List<@NotNull @Valid IncomingMessage> deviceMessages, + @NotNull + @Valid + @Schema(description=""" + A list of synchronization messages to send to companion devices to supply the private keys + associated with the new identity key and their new prekeys. + Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") + List<@NotNull @Valid IncomingMessage> deviceMessages, - @NotNull - @Valid - @Schema(description="The public key of a new signed elliptic-curve prekey pair for each device") - Map devicePniSignedPrekeys, + @NotNull + @Valid + @Schema(description=""" + A new signed elliptic-curve prekey for each enabled device on the account, including this one. + Each must be accompanied by a valid signature from the new identity key in this request.""") + Map devicePniSignedPrekeys, - @NotNull - @Valid - @Schema(description="The new registration ID to use for the phone-number identity of each device") - Map pniRegistrationIds) { + @Schema(description=""" + A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. + May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. + If present, must contain one prekey per enabled device including this one. + Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. + Each must be accompanied by a valid signature from the new identity key in this request.""") + @Valid Map devicePniPqLastResortPrekeys, + + @NotNull + @Valid + @Schema(description="The new registration ID to use for the phone-number identity of each device") + Map pniRegistrationIds) { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { - return devicePniSignedPrekeys.values().parallelStream() - .allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); + List spks = new ArrayList<>(devicePniSignedPrekeys.values()); + if (devicePniPqLastResortPrekeys != null) { + spks.addAll(devicePniPqLastResortPrekeys.values()); + } + return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java index d24dc509f..2d8335504 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKey.java @@ -13,17 +13,17 @@ public class PreKey { @JsonProperty @NotNull - private long keyId; + private long keyId; @JsonProperty @NotEmpty - private String publicKey; + private String publicKey; public PreKey() {} public PreKey(long keyId, String publicKey) { - this.keyId = keyId; + this.keyId = keyId; this.publicKey = publicKey; } @@ -63,5 +63,4 @@ public class PreKey { return ((int)this.keyId) ^ publicKey.hashCode(); } } - } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyCount.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyCount.java index 27df671c3..17bcee889 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyCount.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyCount.java @@ -5,16 +5,22 @@ package org.whispersystems.textsecuregcm.entities; - import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; public class PreKeyCount { + @Schema(description="the number of stored unsigned elliptic-curve prekeys for this device") @JsonProperty private int count; - public PreKeyCount(int count) { - this.count = count; + @Schema(description="the number of stored one-time post-quantum prekeys for this device") + @JsonProperty + private int pqCount; + + public PreKeyCount(int ecCount, int pqCount) { + this.count = ecCount; + this.pqCount = pqCount; } public PreKeyCount() {} @@ -22,4 +28,8 @@ public class PreKeyCount { public int getCount() { return count; } + + public int getPqCount() { + return pqCount; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java index 82cf7ad91..0c983c244 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java @@ -7,15 +7,18 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; +import io.swagger.v3.oas.annotations.media.Schema; import java.util.List; public class PreKeyResponse { @JsonProperty + @Schema(description="the public identity key for the requested identity") private String identityKey; @JsonProperty + @Schema(description="information about each requested device") private List devices; public PreKeyResponse() {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java index 6e9d6bd1f..42491b09f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java @@ -6,28 +6,39 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; +import io.swagger.v3.oas.annotations.media.Schema; public class PreKeyResponseItem { @JsonProperty + @Schema(description="the device ID of the device to which this item pertains") private long deviceId; @JsonProperty + @Schema(description="the registration ID for the device") private int registrationId; @JsonProperty + @Schema(description="the signed elliptic-curve prekey for the device, if one has been set") private SignedPreKey signedPreKey; @JsonProperty + @Schema(description="an unsigned elliptic-curve prekey for the device, if any remain") private PreKey preKey; + @JsonProperty + @Schema(description="a signed post-quantum prekey for the device " + + "(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)") + private SignedPreKey pqPreKey; + public PreKeyResponseItem() {} - public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey) { - this.deviceId = deviceId; + public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) { + this.deviceId = deviceId; this.registrationId = registrationId; - this.signedPreKey = signedPreKey; - this.preKey = preKey; + this.signedPreKey = signedPreKey; + this.preKey = preKey; + this.pqPreKey = pqPreKey; } @VisibleForTesting @@ -40,6 +51,11 @@ public class PreKeyResponseItem { return preKey; } + @VisibleForTesting + public SignedPreKey getPqPreKey() { + return pqPreKey; + } + @VisibleForTesting public int getRegistrationId() { return registrationId; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java index 16db6085e..ae75a19ad 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeySignatureValidator.java @@ -5,24 +5,38 @@ package org.whispersystems.textsecuregcm.entities; import static com.codahale.metrics.MetricRegistry.name; + +import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.util.Base64; +import java.util.Collection; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECPublicKey; public abstract class PreKeySignatureValidator { - public static final boolean validatePreKeySignature(final String identityKeyB64, final SignedPreKey spk) { + public static final Counter INVALID_SIGNATURE_COUNTER = + Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")); + + public static final boolean validatePreKeySignatures(final String identityKeyB64, final Collection spks) { try { final byte[] identityKeyBytes = Base64.getDecoder().decode(identityKeyB64); - final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey()); - final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature()); - final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0); - return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes); + final boolean success = spks.stream().allMatch(spk -> { + final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey()); + final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature()); + + return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes); + }); + + if (!success) { + INVALID_SIGNATURE_COUNTER.increment(); + } + + return success; } catch (IllegalArgumentException | InvalidKeyException e) { - Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")).increment(); + INVALID_SIGNATURE_COUNTER.increment(); return false; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java index b68c64f4f..04287341e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyState.java @@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; import java.util.List; import javax.validation.Valid; import javax.validation.constraints.AssertTrue; @@ -15,26 +17,59 @@ import javax.validation.constraints.NotNull; public class PreKeyState { @JsonProperty - @NotNull @Valid + @Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " + + "If present and not empty, replaces all stored unsigned EC prekeys for the device; " + + "if absent or empty, any stored unsigned EC prekeys for the device are not deleted.") private List preKeys; @JsonProperty - @NotNull @Valid + @Schema(description="An optional signed elliptic-curve prekey to use for this device. " + + "If present, replaces the stored signed elliptic-curve prekey for the device; " + + "if absent, the stored signed prekey is not deleted. " + + "If present, must have a valid signature from the identity key in this request.") private SignedPreKey signedPreKey; + @JsonProperty + @Valid + @Schema(description="A list of signed post-quantum one-time prekeys to use for this device. " + + "Each key must have a valid signature from the identity key in this request. " + + "If present and not empty, replaces all stored unsigned PQ prekeys for the device; " + + "if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.") + private List pqPreKeys; + + @JsonProperty + @Valid + @Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " + + "If present, replaces the stored signed post-quantum last-resort prekey for the device; " + + "if absent, a stored last-resort prekey will *not* be deleted. " + + "If present, must have a valid signature from the identity key in this request.") + private SignedPreKey pqLastResortPreKey; + @JsonProperty @NotEmpty + @NotNull + @Schema(description="Required. " + + "The public identity key for this identity (account or phone-number identity). " + + "If this device is not the primary device for the account, " + + "must match the existing stored identity key for this identity.") private String identityKey; public PreKeyState() {} @VisibleForTesting public PreKeyState(String identityKey, SignedPreKey signedPreKey, List keys) { - this.identityKey = identityKey; - this.signedPreKey = signedPreKey; - this.preKeys = keys; + this(identityKey, signedPreKey, keys, null, null); + } + + @VisibleForTesting + public PreKeyState(String identityKey, SignedPreKey signedPreKey, List keys, List pqKeys, SignedPreKey pqLastResortKey) { + this.identityKey = identityKey; + this.signedPreKey = signedPreKey; + this.preKeys = keys; + this.pqPreKeys = pqKeys; + this.pqLastResortPreKey = pqLastResortKey; } public List getPreKeys() { @@ -45,12 +80,30 @@ public class PreKeyState { return signedPreKey; } + public List getPqPreKeys() { + return pqPreKeys; + } + + public SignedPreKey getPqLastResortPreKey() { + return pqLastResortPreKey; + } + public String getIdentityKey() { return identityKey; } @AssertTrue - public boolean isSignatureValid() { - return PreKeySignatureValidator.validatePreKeySignature(identityKey, signedPreKey); + public boolean isSignatureValidOnEachSignedKey() { + List spks = new ArrayList<>(); + if (pqPreKeys != null) { + spks.addAll(pqPreKeys); + } + if (pqLastResortPreKey != null) { + spks.add(pqLastResortPreKey); + } + if (signedPreKey != null) { + spks.add(signedPreKey); + } + return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java index 2b9e301f8..cb63f2d40 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SignedPreKey.java @@ -45,5 +45,4 @@ public class SignedPreKey extends PreKey { return super.hashCode() ^ signature.hashCode(); } } - } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java index 313985860..a5abc8d56 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AbstractDynamoDbStore.java @@ -12,6 +12,7 @@ import static io.micrometer.core.instrument.Metrics.timer; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Timer; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -53,7 +54,7 @@ public abstract class AbstractDynamoDbStore { return dynamoDbClient; } - protected void executeTableWriteItemsUntilComplete(final Map> items) { + protected void executeTableWriteItemsUntilComplete(final Map> items) { final AtomicReference outcome = new AtomicReference<>(); writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome); int attemptCount = 0; @@ -80,7 +81,7 @@ public abstract class AbstractDynamoDbStore { } private void writeAndStoreOutcome( - final Map> items, + final Map> items, final Timer timer, final AtomicReference outcome) { timer.record( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index bf029ef20..051e5b2c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -245,6 +245,7 @@ public class AccountsManager { public Account changeNumber(final Account account, final String number, @Nullable final String pniIdentityKey, @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { final String originalNumber = account.getNumber(); @@ -252,12 +253,12 @@ public class AccountsManager { if (originalNumber.equals(number)) { if (pniIdentityKey != null) { - throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePNIKeys"); + throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys"); } return account; } - validateDevices(account, pniSignedPreKeys, pniRegistrationIds); + validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); final AtomicReference updatedAccount = new AtomicReference<>(); @@ -281,7 +282,7 @@ public class AccountsManager { numberChangedAccount = updateWithRetries( account, - a -> setPNIKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), + a -> { setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); return true; }, a -> accounts.changeNumber(a, number, phoneNumberIdentifier), () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); @@ -291,45 +292,74 @@ public class AccountsManager { keys.delete(phoneNumberIdentifier); keys.delete(originalPhoneNumberIdentifier); + if (pniPqLastResortPreKeys != null) { + keys.storePqLastResort( + phoneNumberIdentifier, + keys.getPqEnabledDevices(uuid).stream().collect( + Collectors.toMap( + Function.identity(), + pniPqLastResortPreKeys::get))); + } + return displacedUuid; }); return updatedAccount.get(); } - public Account updatePNIKeys(final Account account, + public Account updatePniKeys(final Account account, final String pniIdentityKey, final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, final Map pniRegistrationIds) throws MismatchedDevicesException { - validateDevices(account, pniSignedPreKeys, pniRegistrationIds); + validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); - return update(account, a -> { return setPNIKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); + final UUID pni = account.getPhoneNumberIdentifier(); + final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); + + final List pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni); + keys.delete(pni); + if (pniPqLastResortPreKeys != null) { + keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))); + } + + return updatedAccount; } - private boolean setPNIKeys(final Account account, + private boolean setPniKeys(final Account account, @Nullable final String pniIdentityKey, @Nullable final Map pniSignedPreKeys, @Nullable final Map pniRegistrationIds) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { - return true; + return false; } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); } - pniSignedPreKeys.forEach((deviceId, signedPreKey) -> - account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey))); + boolean changed = !pniIdentityKey.equals(account.getPhoneNumberIdentityKey()); + + for (Device device : account.getDevices()) { + if (!device.isEnabled()) { + continue; + } + SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId()); + int registrationId = pniRegistrationIds.get(device.getId()); + changed = changed || + !signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) || + device.getRegistrationId() != registrationId; + device.setPhoneNumberIdentitySignedPreKey(signedPreKey); + device.setPhoneNumberIdentityRegistrationId(registrationId); + } - pniRegistrationIds.forEach((deviceId, registrationId) -> - account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId))); + account.setPhoneNumberIdentityKey(pniIdentityKey); - account.setPhoneNumberIdentityKey(pniIdentityKey); - - return true; + return changed; } private void validateDevices(final Account account, - final Map pniSignedPreKeys, - final Map pniRegistrationIds) throws MismatchedDevicesException { + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, + @Nullable final Map pniRegistrationIds) throws MismatchedDevicesException { if (pniSignedPreKeys == null && pniRegistrationIds == null) { return; } else if (pniSignedPreKeys == null || pniRegistrationIds == null) { @@ -342,6 +372,12 @@ public class AccountsManager { pniSignedPreKeys.keySet(), Collections.emptySet()); + // Check that all including master ID are in Pq pre-keys + DestinationDeviceValidator.validateCompleteDeviceList( + account, + pniSignedPreKeys.keySet(), + Collections.emptySet()); + // Check that all devices are accounted for in the map of new PNI registration IDs DestinationDeviceValidator.validateCompleteDeviceList( account, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index c83ff1831..3015bdc2f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -42,6 +42,7 @@ public class ChangeNumberManager { public Account changeNumber(final Account account, final String number, @Nullable final String pniIdentityKey, @Nullable final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, @Nullable final List deviceMessages, @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException, StaleDevicesException { @@ -62,10 +63,14 @@ public class ChangeNumberManager { // We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key // material and distribute the sync messages, to be sure all clients agree with us and each other about what their // keys are. Pretend this change-number request was actually a PNI key distribution request. - return updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds); + if (pniIdentityKey == null) { + return account; + } + return updatePniKeys(account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, deviceMessages, pniRegistrationIds); } - final Account updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); + final Account updatedAccount = accountsManager.changeNumber( + account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds); if (deviceMessages != null) { sendDeviceMessages(updatedAccount, deviceMessages); @@ -74,16 +79,18 @@ public class ChangeNumberManager { return updatedAccount; } - public Account updatePNIKeys(final Account account, + public Account updatePniKeys(final Account account, final String pniIdentityKey, final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, final List deviceMessages, final Map pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { validateDeviceMessages(account, deviceMessages); // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb // write anyway. Linked devices can handle some wasted extra key rotations. - final Account updatedAccount = accountsManager.updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); + final Account updatedAccount = accountsManager.updatePniKeys( + account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds); sendDeviceMessages(updatedAccount, deviceMessages); return updatedAccount; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index b9a35a268..4c92c0377 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -6,6 +6,9 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Multimap; +import com.google.common.collect.MultimapBuilder; +import com.google.common.collect.Multimaps; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; @@ -16,7 +19,11 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -34,11 +41,14 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; public class Keys extends AbstractDynamoDbStore { - private final String tableName; + private final String ecTableName; + private final String pqTableName; + private final String pqLastResortTableName; static final String KEY_ACCOUNT_UUID = "U"; static final String KEY_DEVICE_ID_KEY_ID = "DK"; static final String KEY_PUBLIC_KEY = "P"; + static final String KEY_SIGNATURE = "S"; private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys")); private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice")); @@ -48,31 +58,114 @@ public class Keys extends AbstractDynamoDbStore { private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys")); private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount")); private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty")); + private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys")); - public Keys(final DynamoDbClient dynamoDB, final String tableName) { + public Keys( + final DynamoDbClient dynamoDB, + final String ecTableName, + final String pqTableName, + final String pqLastResortTableName) { super(dynamoDB); - this.tableName = tableName; + this.ecTableName = ecTableName; + this.pqTableName = pqTableName; + this.pqLastResortTableName = pqLastResortTableName; } public void store(final UUID identifier, final long deviceId, final List keys) { - STORE_KEYS_TIMER.record(() -> { - delete(identifier, deviceId); + store(identifier, deviceId, keys, null, null); + } - writeInBatches(keys, batch -> { - List items = new ArrayList<>(); - for (final PreKey preKey : batch) { - items.add(WriteRequest.builder() - .putRequest(PutRequest.builder() - .item(getItemFromPreKey(identifier, deviceId, preKey)) - .build()) - .build()); - } - executeTableWriteItemsUntilComplete(Map.of(tableName, items)); + public void store( + final UUID identifier, final long deviceId, + @Nullable final List ecKeys, + @Nullable final List pqKeys, + @Nullable final SignedPreKey pqLastResortKey) { + Multimap keys = MultimapBuilder.hashKeys().arrayListValues().build(); + List tablesToClear = new ArrayList<>(); + + if (ecKeys != null && !ecKeys.isEmpty()) { + keys.putAll(ecTableName, ecKeys); + tablesToClear.add(ecTableName); + } + if (pqKeys != null && !pqKeys.isEmpty()) { + keys.putAll(pqTableName, pqKeys); + tablesToClear.add(pqTableName); + } + if (pqLastResortKey != null) { + keys.put(pqLastResortTableName, pqLastResortKey); + tablesToClear.add(pqLastResortTableName); + } + + STORE_KEYS_TIMER.record(() -> { + delete(tablesToClear, identifier, deviceId); + + writeInBatches( + keys.entries(), + batch -> { + Multimap writes = batch.stream() + .collect( + Multimaps.toMultimap( + Map.Entry::getKey, + entry -> WriteRequest.builder() + .putRequest(PutRequest.builder() + .item(getItemFromPreKey(identifier, deviceId, entry.getValue())) + .build()) + .build(), + MultimapBuilder.hashKeys().arrayListValues()::build)); + executeTableWriteItemsUntilComplete(writes.asMap()); }); }); } - public Optional take(final UUID identifier, final long deviceId) { + public void storePqLastResort(final UUID identifier, final Map keys) { + final AttributeValue partitionKey = getPartitionKey(identifier); + final QueryRequest queryRequest = QueryRequest.builder() + .tableName(pqLastResortTableName) + .keyConditionExpression("#uuid = :uuid") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of(":uuid", partitionKey)) + .projectionExpression(KEY_DEVICE_ID_KEY_ID) + .consistentRead(true) + .build(); + + final List writes = new ArrayList<>(2 * keys.size()); + final Map> newItems = keys.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue()))); + + for (final Map item : db().query(queryRequest).items()) { + final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID); + final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong(); + if (newItems.containsKey(oldDeviceId)) { + final Map replacement = newItems.get(oldDeviceId); + if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) { + writes.add(WriteRequest.builder() + .deleteRequest(DeleteRequest.builder() + .key(Map.of( + KEY_ACCOUNT_UUID, partitionKey, + KEY_DEVICE_ID_KEY_ID, oldSortKey)) + .build()) + .build()); + } + } + } + + newItems.forEach((unusedKey, item) -> + writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build())); + + executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes)); + } + + public Optional takeEC(final UUID identifier, final long deviceId) { + return take(ecTableName, identifier, deviceId); + } + + public Optional takePQ(final UUID identifier, final long deviceId) { + return take(pqTableName, identifier, deviceId) + .or(() -> getLastResort(identifier, deviceId)) + .map(pk -> (SignedPreKey) pk); + } + + private Optional take(final String tableName, final UUID identifier, final long deviceId) { return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { final AttributeValue partitionKey = getPartitionKey(identifier); QueryRequest queryRequest = QueryRequest.builder() @@ -114,7 +207,53 @@ public class Keys extends AbstractDynamoDbStore { }); } - public int getCount(final UUID identifier, final long deviceId) { + @VisibleForTesting + Optional getLastResort(final UUID identifier, final long deviceId) { + final AttributeValue partitionKey = getPartitionKey(identifier); + QueryRequest queryRequest = QueryRequest.builder() + .tableName(pqLastResortTableName) + .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .expressionAttributeValues(Map.of( + ":uuid", partitionKey, + ":sortprefix", getSortKeyPrefix(deviceId))) + .consistentRead(false) + .select(Select.ALL_ATTRIBUTES) + .build(); + + QueryResponse response = db().query(queryRequest); + if (response.count() > 1) { + TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment(); + } + return response.items().stream().findFirst().map(this::getPreKeyFromItem); + } + + public List getPqEnabledDevices(final UUID identifier) { + final AttributeValue partitionKey = getPartitionKey(identifier); + final QueryRequest queryRequest = QueryRequest.builder() + .tableName(pqLastResortTableName) + .keyConditionExpression("#uuid = :uuid") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of(":uuid", partitionKey)) + .projectionExpression(KEY_DEVICE_ID_KEY_ID) + .consistentRead(false) + .build(); + + final QueryResponse response = db().query(queryRequest); + return response.items().stream() + .map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong()) + .toList(); + } + + public int getEcCount(final UUID identifier, final long deviceId) { + return getCount(ecTableName, identifier, deviceId); + } + + public int getPqCount(final UUID identifier, final long deviceId) { + return getCount(pqTableName, identifier, deviceId); + } + + private int getCount(final String tableName, final UUID identifier, final long deviceId) { return GET_KEY_COUNT_TIMER.record(() -> { QueryRequest queryRequest = QueryRequest.builder() .tableName(tableName) @@ -144,51 +283,66 @@ public class Keys extends AbstractDynamoDbStore { public void delete(final UUID accountUuid) { DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { final QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) .keyConditionExpression("#uuid = :uuid") .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(accountUuid))) + ":uuid", getPartitionKey(accountUuid))) .projectionExpression(KEY_DEVICE_ID_KEY_ID) .consistentRead(true) .build(); - deleteItemsForAccountMatchingQuery(accountUuid, queryRequest); + deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest); }); } public void delete(final UUID accountUuid, final long deviceId) { + delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId); + } + + private void delete(final List tableNames, final UUID accountUuid, final long deviceId) { DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { final QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(accountUuid), - ":sortprefix", getSortKeyPrefix(deviceId))) + ":uuid", getPartitionKey(accountUuid), + ":sortprefix", getSortKeyPrefix(deviceId))) .projectionExpression(KEY_DEVICE_ID_KEY_ID) .consistentRead(true) .build(); - deleteItemsForAccountMatchingQuery(accountUuid, queryRequest); + deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest); }); } - private void deleteItemsForAccountMatchingQuery(final UUID accountUuid, final QueryRequest querySpec) { + private void deleteItemsForAccountMatchingQuery(final List tableNames, final UUID accountUuid, final QueryRequest querySpec) { final AttributeValue partitionKey = getPartitionKey(accountUuid); - writeInBatches(db().query(querySpec).items(), batch -> { - List deletes = new ArrayList<>(); - for (final Map item : batch) { - deletes.add(WriteRequest.builder() - .deleteRequest(DeleteRequest.builder() - .key(Map.of( - KEY_ACCOUNT_UUID, partitionKey, - KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID))) - .build()) - .build()); - } - executeTableWriteItemsUntilComplete(Map.of(tableName, deletes)); + Multimap> itemStream = tableNames.stream() + .collect( + Multimaps.flatteningToMultimap( + Function.identity(), + tableName -> + db().query(querySpec.toBuilder().tableName(tableName).build()) + .items() + .stream(), + MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build)); + + writeInBatches( + itemStream.entries(), + batch -> { + Multimap deletes = batch.stream() + .collect(Multimaps.toMultimap( + Map.Entry>::getKey, + entry -> WriteRequest.builder() + .deleteRequest(DeleteRequest.builder() + .key(Map.of( + KEY_ACCOUNT_UUID, partitionKey, + KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID))) + .build()) + .build(), + MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build)); + executeTableWriteItemsUntilComplete(deletes.asMap()); }); } @@ -211,6 +365,13 @@ public class Keys extends AbstractDynamoDbStore { } private Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { + if (preKey instanceof final SignedPreKey spk) { + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), + KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()), + KEY_PUBLIC_KEY, AttributeValues.fromString(spk.getPublicKey()), + KEY_SIGNATURE, AttributeValues.fromString(spk.getSignature())); + } return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), @@ -219,6 +380,11 @@ public class Keys extends AbstractDynamoDbStore { private PreKey getPreKeyFromItem(Map item) { final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + if (item.containsKey(KEY_SIGNATURE)) { + // All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored + // in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys. + return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s()); + } return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java index 058417d39..3179b921c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java @@ -174,7 +174,9 @@ public class AssignUsernameCommand extends EnvironmentCommand) invocation -> { + when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); final String pniIdentityKey = invocation.getArgument(2, String.class); @@ -358,7 +358,7 @@ class AccountControllerTest { return updatedAccount; }); - when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { + when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String pniIdentityKey = invocation.getArgument(1, String.class); @@ -1377,12 +1377,12 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any()); assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.number()).isEqualTo(number); @@ -1399,12 +1399,12 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.readEntity(String.class)).isBlank(); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1417,7 +1417,7 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(400); @@ -1426,7 +1426,7 @@ class AccountControllerTest { assertThat(responseEntity.getOriginalNumber()).isEqualTo(number); assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111"); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1436,10 +1436,10 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any()); } @Test @@ -1454,11 +1454,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1478,13 +1478,13 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT); assertThat(response.getStatus()).isEqualTo(403); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1514,11 +1514,11 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any()); } @Test @@ -1549,14 +1549,14 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); // verify(existingAccount).lockAuthenticationCredentials(); // verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any()); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1589,14 +1589,14 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(423); // verify(existingAccount).lockAuthenticationCredentials(); // verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any()); - verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); + verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any()); } @Test @@ -1628,13 +1628,13 @@ class AccountControllerTest { .target("/v1/accounts/number") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), + .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(200); verify(senderRegLockAccount, never()).lockAuthTokenHash(); verify(clientPresenceManager, never()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any()); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any()); } @Test @@ -1681,10 +1681,11 @@ class AccountControllerTest { number, code, null, pniIdentityKey, deviceMessages, deviceKeys, + null, registrationIds), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); - verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any()); assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.number()).isEqualTo(number); @@ -1734,11 +1735,12 @@ class AccountControllerTest { AuthHelper.VALID_NUMBER, code, null, pniIdentityKey, deviceMessages, deviceKeys, + null, registrationIds), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(changeNumberManager).changeNumber( - eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any()); + eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any(), any()); verifyNoInteractions(rateLimiter); verifyNoInteractions(pendingAccountsManager); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index d9bd36f77..ac47d4c25 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -134,7 +134,7 @@ class AccountControllerV2Test { void setUp() throws Exception { when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); - when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer( + when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer( (Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); @@ -180,11 +180,11 @@ class AccountControllerV2Test { .put(Entity.entity( new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123", Collections.emptyList(), - Collections.emptyMap(), Collections.emptyMap()), + Collections.emptyMap(), null, Collections.emptyMap()), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), - any()); + any(), any()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(NEW_NUMBER, accountIdentityResponse.number()); @@ -203,11 +203,11 @@ class AccountControllerV2Test { new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, "pni-identity-key", Collections.emptyList(), - Collections.emptyMap(), Collections.emptyMap()), + Collections.emptyMap(), null, Collections.emptyMap()), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), - any()); + any(), any()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); @@ -365,7 +365,7 @@ class AccountControllerV2Test { final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), - any()); + any(), any()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(NEW_NUMBER, accountIdentityResponse.number()); @@ -458,7 +458,7 @@ class AccountControllerV2Test { @BeforeEach void setUp() throws Exception { - when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer( + when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer( (Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); final String pniIdentityKey = invocation.getArgument(1, String.class); @@ -496,7 +496,7 @@ class AccountControllerV2Test { AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(requestJson()), AccountIdentityResponse.class); - verify(changeNumberManager).updatePNIKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any()); + verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any(), any()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); @@ -557,6 +557,7 @@ class AccountControllerV2Test { "pniIdentityKey": "pni-identity-key", "deviceMessages": [], "devicePniSignedPrekeys": {}, + "devicePniSignedPqPrekeys": {}, "pniRegistrationIds": {} } """; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index f4511b0b9..4350fd9c4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -128,7 +128,7 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - accountsManager.changeNumber(account, secondNumber, null, null, null); + accountsManager.changeNumber(account, secondNumber, null, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -161,7 +161,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Map preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); - final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, registrationIds); + final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -191,8 +191,8 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - account = accountsManager.changeNumber(account, secondNumber, null, null, null); - accountsManager.changeNumber(account, originalNumber, null, null, null); + account = accountsManager.changeNumber(account, secondNumber, null, null, null, null); + accountsManager.changeNumber(account, originalNumber, null, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isPresent()); assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow()); @@ -217,7 +217,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final UUID existingAccountUuid = existingAccount.getUuid(); - accountsManager.changeNumber(account, secondNumber, null, null, null); + accountsManager.changeNumber(account, secondNumber, null, null, null, null); assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); @@ -231,7 +231,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); - accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null); + accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null); final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); @@ -251,7 +251,7 @@ class AccountsManagerChangeNumberIntegrationTest { final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final UUID existingAccountUuid = existingAccount.getUuid(); - final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null); + final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); @@ -262,7 +262,7 @@ class AccountsManagerChangeNumberIntegrationTest { assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); - final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null); + final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null, null); assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 138a9c967..7b4f18da1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -641,7 +642,7 @@ class AccountsManagerTest { final UUID originalPni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, targetNumber, null, null, null); + account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); assertEquals(targetNumber, account.getNumber()); @@ -656,7 +657,7 @@ class AccountsManagerTest { final String number = "+14152222222"; Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, number, null, null, null); + account = accountsManager.changeNumber(account, number, null, null, null, null); assertEquals(number, account.getNumber()); verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any()); @@ -664,13 +665,13 @@ class AccountsManagerTest { } @Test - void testChangePhoneNumberSameNumberWithPNIData() { + void testChangePhoneNumberSameNumberWithPniData() { final String number = "+14152222222"; Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); assertThrows(IllegalArgumentException.class, () -> accountsManager.changeNumber( - account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), Map.of(1L, 101)), + account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)), "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); verify(accounts, never()).update(any()); @@ -694,14 +695,60 @@ class AccountsManagerTest { when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); - account = accountsManager.changeNumber(account, targetNumber, null, null, null); + account = accountsManager.changeNumber(account, targetNumber, null, null, null, null); assertEquals(targetNumber, account.getNumber()); assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); + final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); + verify(keys).delete(existingAccountUuid); verify(keys).delete(originalPni); - verify(keys).delete(targetPni); + verify(keys, atLeastOnce()).delete(targetPni); + verify(keys).delete(newPni); + verifyNoMoreInteractions(keys); + } + + @Test + void testChangePhoneNumberWithPqKeysExistingAccount() throws InterruptedException, MismatchedDevicesException { + doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty())) + .when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any()); + + final String originalNumber = "+14152222222"; + final String targetNumber = "+14153333333"; + final UUID existingAccountUuid = UUID.randomUUID(); + final UUID uuid = UUID.randomUUID(); + final UUID originalPni = UUID.randomUUID(); + final UUID targetPni = UUID.randomUUID(); + final Map newSignedKeys = Map.of( + 1L, new SignedPreKey(1L, "pub1", "sig1"), + 2L, new SignedPreKey(2L, "pub2", "sig2")); + final Map newSignedPqKeys = Map.of( + 1L, new SignedPreKey(3L, "pub3", "sig3"), + 2L, new SignedPreKey(4L, "pub4", "sig4")); + final Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]); + when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); + when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L)); + + final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); + final Account updatedAccount = accountsManager.changeNumber( + account, targetNumber, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds); + + assertEquals(targetNumber, updatedAccount.getNumber()); + + assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); + + final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); + verify(keys).delete(existingAccountUuid); + verify(keys, atLeastOnce()).delete(targetPni); + verify(keys).delete(newPni); + verify(keys).delete(originalPni); + verify(keys).getPqEnabledDevices(uuid); + verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, new SignedPreKey(3L, "pub3", "sig3")))); + verifyNoMoreInteractions(keys); } @Test @@ -716,7 +763,7 @@ class AccountsManagerTest { } @Test - void testPNIUpdate() throws MismatchedDevicesException { + void testPniUpdate() throws MismatchedDevicesException { final String number = "+14152222222"; List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); @@ -730,7 +777,7 @@ class AccountsManagerTest { UUID oldPni = account.getPhoneNumberIdentifier(); Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); - final Account updatedAccount = accountsManager.updatePNIKeys(account, "new-pni-identity-key", newSignedKeys, newRegistrationIds); + final Account updatedAccount = accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, null, newRegistrationIds); // non-PNI stuff should not change assertEquals(oldUuid, updatedAccount.getUuid()); @@ -750,7 +797,57 @@ class AccountsManagerTest { verify(accounts).update(any()); verifyNoInteractions(deletedAccountsManager); - verifyNoInteractions(keys); + + verify(keys).delete(oldPni); + } + + @Test + void testPniPqUpdate() throws MismatchedDevicesException { + final String number = "+14152222222"; + + List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]); + Map newSignedKeys = Map.of( + 1L, new SignedPreKey(1L, "pub1", "sig1"), + 2L, new SignedPreKey(2L, "pub2", "sig2")); + Map newSignedPqKeys = Map.of( + 1L, new SignedPreKey(3L, "pub3", "sig3"), + 2L, new SignedPreKey(4L, "pub4", "sig4")); + Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + + UUID oldUuid = account.getUuid(); + UUID oldPni = account.getPhoneNumberIdentifier(); + + when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); + + Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); + + final Account updatedAccount = + accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds); + + // non-PNI-keys stuff should not change + assertEquals(oldUuid, updatedAccount.getUuid()); + assertEquals(number, updatedAccount.getNumber()); + assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); + assertEquals(null, updatedAccount.getIdentityKey()); + assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey))); + assertEquals(Map.of(1L, 101, 2L, 102), + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); + + // PNI keys should + assertEquals("new-pni-identity-key", updatedAccount.getPhoneNumberIdentityKey()); + assertEquals(newSignedKeys, + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey))); + assertEquals(newRegistrationIds, + updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); + + verify(accounts).update(any()); + verifyNoInteractions(deletedAccountsManager); + + verify(keys).delete(oldPni); + + // only the pq key for the already-pq-enabled device should be saved + verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index a995939a2..531425092 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -47,7 +47,7 @@ public class ChangeNumberManagerTest { updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); - when(accountsManager.changeNumber(any(), any(), any(), any(), any())).thenAnswer((Answer)invocation -> { + when(accountsManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer)invocation -> { final Account account = invocation.getArgument(0, Account.class); final String number = invocation.getArgument(1, String.class); @@ -70,7 +70,7 @@ public class ChangeNumberManagerTest { return updatedAccount; }); - when(accountsManager.updatePNIKeys(any(), any(), any(), any())).thenAnswer((Answer)invocation -> { + when(accountsManager.updatePniKeys(any(), any(), any(), any(), any())).thenAnswer((Answer)invocation -> { final Account account = invocation.getArgument(0, Account.class); final UUID uuid = account.getUuid(); @@ -94,8 +94,8 @@ public class ChangeNumberManagerTest { void changeNumberNoMessages() throws Exception { Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); - changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null); - verify(accountsManager).changeNumber(account, "+18025551234", null, null, null); + changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null); + verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); verify(accountsManager, never()).updateDevice(any(), eq(1L), any()); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); } @@ -107,8 +107,8 @@ public class ChangeNumberManagerTest { var prekeys = Map.of(1L, new SignedPreKey()); final String pniIdentityKey = "pni-identity-key"; - changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyList(), Collections.emptyMap()); - verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyMap()); + changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); + verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); } @@ -139,9 +139,53 @@ public class ChangeNumberManagerTest { when(msg.destinationDeviceId()).thenReturn(2L); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); - changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); + changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); - verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, registrationIds); + verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); + + final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); + verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + + assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); + assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(Device.MASTER_ID, envelope.getSourceDevice()); + assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); + } + + + @Test + void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { + final String originalE164 = "+18005551234"; + final String changedE164 = "+18025551234"; + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getNumber()).thenReturn(originalE164); + when(account.getUuid()).thenReturn(aci); + when(account.getPhoneNumberIdentifier()).thenReturn(pni); + + final Device d2 = mock(Device.class); + when(d2.isEnabled()).thenReturn(true); + when(d2.getId()).thenReturn(2L); + + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevices()).thenReturn(List.of(d2)); + + final String pniIdentityKey = "pni-identity-key"; + final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 19); + + final IncomingMessage msg = mock(IncomingMessage.class); + when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + + changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); + + verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds); final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); @@ -174,15 +218,16 @@ public class ChangeNumberManagerTest { final String pniIdentityKey = "pni-identity-key"; final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); final Map registrationIds = Map.of(1L, 17, 2L, 19); final IncomingMessage msg = mock(IncomingMessage.class); when(msg.destinationDeviceId()).thenReturn(2L); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); - changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); + changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); - verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); + verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); @@ -196,7 +241,7 @@ public class ChangeNumberManagerTest { } @Test - void updatePNIKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { + void updatePniKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { final UUID aci = UUID.randomUUID(); final UUID pni = UUID.randomUUID(); @@ -219,9 +264,49 @@ public class ChangeNumberManagerTest { when(msg.destinationDeviceId()).thenReturn(2L); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); - changeNumberManager.updatePNIKeys(account, pniIdentityKey, prekeys, List.of(msg), registrationIds); + changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); - verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); + verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds); + + final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); + verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); + + final MessageProtos.Envelope envelope = envelopeCaptor.getValue(); + + assertEquals(aci, UUID.fromString(envelope.getDestinationUuid())); + assertEquals(aci, UUID.fromString(envelope.getSourceUuid())); + assertEquals(Device.MASTER_ID, envelope.getSourceDevice()); + assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account)); + } + + @Test + void updatePniKeysSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { + final UUID aci = UUID.randomUUID(); + final UUID pni = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(aci); + when(account.getPhoneNumberIdentifier()).thenReturn(pni); + + final Device d2 = mock(Device.class); + when(d2.isEnabled()).thenReturn(true); + when(d2.getId()).thenReturn(2L); + + when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevices()).thenReturn(List.of(d2)); + + final String pniIdentityKey = "pni-identity-key"; + final Map prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); + final Map pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey()); + final Map registrationIds = Map.of(1L, 17, 2L, 19); + + final IncomingMessage msg = mock(IncomingMessage.class); + when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); + + changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); + + verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds); final ArgumentCaptor envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); @@ -261,11 +346,11 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, - () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds)); + () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, null, messages, registrationIds)); } @Test - void updatePNIKeysMismatchedRegistrationId() { + void updatePniKeysMismatchedRegistrationId() { final Account account = mock(Account.class); when(account.getNumber()).thenReturn("+18005551234"); @@ -291,7 +376,7 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(StaleDevicesException.class, - () -> changeNumberManager.updatePNIKeys(account, "pni-identity-key", preKeys, messages, registrationIds)); + () -> changeNumberManager.updatePniKeys(account, "pni-identity-key", preKeys, null, messages, registrationIds)); } @Test @@ -320,6 +405,6 @@ public class ChangeNumberManagerTest { final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); assertThrows(IllegalArgumentException.class, - () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, messages, registrationIds)); + () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, null, messages, registrationIds)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java index a4b5fea3a..9cff9c791 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java @@ -69,7 +69,35 @@ public final class DynamoDbExtensionSchema { .build()), List.of(), List.of()), - KEYS("keys_test", + EC_KEYS("keys_test", + Keys.KEY_ACCOUNT_UUID, + Keys.KEY_DEVICE_ID_KEY_ID, + List.of( + AttributeDefinition.builder() + .attributeName(Keys.KEY_ACCOUNT_UUID) + .attributeType(ScalarAttributeType.B) + .build(), + AttributeDefinition.builder() + .attributeName(Keys.KEY_DEVICE_ID_KEY_ID) + .attributeType(ScalarAttributeType.B) + .build()), + List.of(), List.of()), + + PQ_KEYS("pq_keys_test", + Keys.KEY_ACCOUNT_UUID, + Keys.KEY_DEVICE_ID_KEY_ID, + List.of( + AttributeDefinition.builder() + .attributeName(Keys.KEY_ACCOUNT_UUID) + .attributeType(ScalarAttributeType.B) + .build(), + AttributeDefinition.builder() + .attributeName(Keys.KEY_DEVICE_ID_KEY_ID) + .attributeType(ScalarAttributeType.B) + .build()), + List.of(), List.of()), + + PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test", Keys.KEY_ACCOUNT_UUID, Keys.KEY_DEVICE_ID_KEY_ID, List.of( diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java index 6c21e4f1f..cb6b4f8a0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java @@ -6,99 +6,244 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.util.AttributeValues; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; +import software.amazon.awssdk.services.dynamodb.model.Select; class KeysTest { private Keys keys; @RegisterExtension - static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.KEYS); + static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( + Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS); private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final long DEVICE_ID = 1L; @BeforeEach void setup() { - keys = new Keys(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.KEYS.tableName()); + keys = new Keys( + DYNAMO_DB_EXTENSION.getDynamoDbClient(), + Tables.EC_KEYS.tableName(), + Tables.PQ_KEYS.tableName(), + Tables.PQ_LAST_RESORT_KEYS.tableName()); } @Test void testStore() { - assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Initial pre-key count for an account should be zero"); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Initial pre-key count for an account should be zero"); + assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), + "Initial last-resort pre-key for an account should be missing"); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID), + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Repeatedly storing same key should have no effect"); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting a new key should overwrite all prior keys for the given account/device"); + keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1, "pq-public-key", "sig")), null); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ prekeys should have no effect on EC prekeys"); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); - assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID), + keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, new SignedPreKey(1001, "pq-last-resort-key", "sig")); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); + assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); + + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")), null, null); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new EC prekeys should have no effect on PQ prekeys"); + + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key")), List.of(new SignedPreKey(2, "different-pq-public-key", "sig")), null); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + + keys.store(ACCOUNT_UUID, DEVICE_ID, + List.of(new PreKey(4, "fourth-public-key"), new PreKey(5, "fifth-public-key")), + List.of(new SignedPreKey(6, "sixth-pq-key", "sig"), new SignedPreKey(7, "seventh-pq-key", "sig")), + new SignedPreKey(1002, "new-last-resort-key", "sig")); + assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); + assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting multiple new keys should overwrite all prior keys for the given account/device"); + assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(), + "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); } @Test void testTakeAccountAndDeviceId() { - assertEquals(Optional.empty(), keys.take(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); final PreKey preKey = new PreKey(1, "public-key"); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); - assertEquals(Optional.of(preKey), keys.take(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(Optional.of(preKey), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + } + + @Test + void testTakePQ() { + assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); + + final SignedPreKey preKey1 = new SignedPreKey(1, "public-key", "sig"); + final SignedPreKey preKey2 = new SignedPreKey(2, "different-public-key", "sig"); + final SignedPreKey preKeyLast = new SignedPreKey(1001, "last-public-key", "sig"); + + keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); + + assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); } @Test void testGetCount() { - assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")), List.of(new SignedPreKey(1, "public-pq-key", "sig")), null); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); } @Test void testDeleteByAccount() { - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); - keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, + List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")), + List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")), + new SignedPreKey(5, "last-pq-key", "sig")); - assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); + keys.store(ACCOUNT_UUID, DEVICE_ID + 1, + List.of(new PreKey(6, "public-key-for-different-device")), + List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")), + new SignedPreKey(8, "last-pq-key-for-different-device", "sig")); + + assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); keys.delete(ACCOUNT_UUID); - assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); } @Test void testDeleteByAccountAndDevice() { - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); - keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, + List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")), + List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")), + new SignedPreKey(5, "last-pq-key", "sig")); - assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); + keys.store(ACCOUNT_UUID, DEVICE_ID + 1, + List.of(new PreKey(6, "public-key-for-different-device")), + List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")), + new SignedPreKey(8, "last-pq-key-for-different-device", "sig")); + + assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); keys.delete(ACCOUNT_UUID, DEVICE_ID); - assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + } + + @Test + void testStorePqLastResort() { + assertEquals(0, getLastResortCount(ACCOUNT_UUID)); + + keys.storePqLastResort( + ACCOUNT_UUID, + Map.of(1L, new SignedPreKey(1L, "pub1", "sig1"), 2L, new SignedPreKey(2L, "pub2", "sig2"))); + assertEquals(2, getLastResortCount(ACCOUNT_UUID)); + assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId()); + assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId()); + assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent()); + + keys.storePqLastResort( + ACCOUNT_UUID, + Map.of(1L, new SignedPreKey(3L, "pub3", "sig3"), 3L, new SignedPreKey(4L, "pub4", "sig4"))); + assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates"); + assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone"); + assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); + } + + private int getLastResortCount(UUID uuid) { + QueryRequest queryRequest = QueryRequest.builder() + .tableName(Tables.PQ_LAST_RESORT_KEYS.tableName()) + .keyConditionExpression("#uuid = :uuid") + .expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid))) + .select(Select.COUNT) + .build(); + QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest); + return response.count(); + } + + @Test + void testGetPqEnabledDevices() { + keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1L, "pub1", "sig1")), null); + keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, new SignedPreKey(2L, "pub2", "sig2")); + keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(new SignedPreKey(3L, "pub3", "sig3")), new SignedPreKey(4L, "pub4", "sig4")); + keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null); + assertIterableEquals( + Set.of(DEVICE_ID + 1, DEVICE_ID + 2), + Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID))); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index d4c7410f9..02a985cb0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; @@ -86,19 +87,25 @@ class KeysControllerTest { private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final String PNI_IDENTITY_KEY = KeysHelper.serializeIdentityKey(PNI_IDENTITY_KEY_PAIR); - private final PreKey SAMPLE_KEY = new PreKey(1234, "test1"); - private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3"); - private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5"); - private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6"); + private final PreKey SAMPLE_KEY = new PreKey(1234, "test1"); + private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3"); + private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5"); + private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6"); private final PreKey SAMPLE_KEY_PNI = new PreKey(7777, "test7"); - private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey( 1111, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey( 2222, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey( 3333, IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey( 4444, PNI_IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey( 5555, PNI_IDENTITY_KEY_PAIR); - private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey( 6666, PNI_IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_PQ_KEY = new SignedPreKey(2424, "test1", "sig"); + private final SignedPreKey SAMPLE_PQ_KEY2 = new SignedPreKey(6868, "test3", "sig"); + private final SignedPreKey SAMPLE_PQ_KEY3 = new SignedPreKey(1313, "test5", "sig"); + + private final SignedPreKey SAMPLE_PQ_KEY_PNI = new SignedPreKey(8888, "test7", "sig"); + + private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey(1111, IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey(2222, IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey(3333, IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey(4444, PNI_IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey(5555, PNI_IDENTITY_KEY_PAIR); + private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey(6666, PNI_IDENTITY_KEY_PAIR); private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedPreKey(89898, IDENTITY_KEY_PAIR); private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedPreKey(7777, PNI_IDENTITY_KEY_PAIR); @@ -177,10 +184,13 @@ class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.take(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); + when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); + when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI)); - when(KEYS.getCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); + when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); + when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); @@ -210,8 +220,10 @@ class KeysControllerTest { .get(PreKeyCount.class); assertThat(result.getCount()).isEqualTo(5); + assertThat(result.getPqCount()).isEqualTo(5); - verify(KEYS).getCount(AuthHelper.VALID_UUID, 1); + verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1); + verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1); } @@ -223,9 +235,7 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(SignedPreKey.class); - assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature()); - assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId()); - assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey()); + assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result); } @Test @@ -237,9 +247,7 @@ class KeysControllerTest { .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(SignedPreKey.class); - assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getSignature()); - assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getKeyId()); - assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getPublicKey()); + assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result); } @Test @@ -291,19 +299,63 @@ class KeysControllerTest { @Test void validSingleRequestTestV2() { PreKeyResponse result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(PreKeyResponse.class); + .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); - assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); + assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); + assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); - verify(KEYS).take(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, 1); + verifyNoMoreInteractions(KEYS); + } + + @Test + void validSingleRequestPqTestNoPqKeysV2() { + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.empty()); + + PreKeyResponse result = resources.getJerseyTest() + .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) + .queryParam("pq", "true") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); + + assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); + assertThat(result.getDevicesCount()).isEqualTo(1); + assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertThat(result.getDevice(1).getPqPreKey()).isNull(); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + + verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).takePQ(EXISTS_UUID, 1); + verifyNoMoreInteractions(KEYS); + } + + @Test + void validSingleRequestPqTestV2() { + PreKeyResponse result = resources.getJerseyTest() + .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) + .queryParam("pq", "true") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); + + assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); + assertThat(result.getDevicesCount()).isEqualTo(1); + assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); + + verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).takePQ(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -317,12 +369,33 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); - assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); + assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); + assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); - verify(KEYS).take(EXISTS_PNI, 1); + verify(KEYS).takeEC(EXISTS_PNI, 1); + verifyNoMoreInteractions(KEYS); + } + + @Test + void validSingleRequestPqByPhoneNumberIdentifierTestV2() { + PreKeyResponse result = resources.getJerseyTest() + .target(String.format("/v2/keys/%s/1", EXISTS_PNI)) + .queryParam("pq", "true") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); + + assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); + assertThat(result.getDevicesCount()).isEqualTo(1); + assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); + assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); + assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); + + verify(KEYS).takeEC(EXISTS_PNI, 1); + verify(KEYS).takePQ(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); } @@ -338,12 +411,12 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); - assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); + assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); + assertThat(result.getDevice(1).getPqPreKey()).isNull(); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); + assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey()); - verify(KEYS).take(EXISTS_PNI, 1); + verify(KEYS).takeEC(EXISTS_PNI, 1); verifyNoMoreInteractions(KEYS); } @@ -365,18 +438,20 @@ class KeysControllerTest { @Test void testUnidentifiedRequest() { PreKeyResponse result = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) - .request() - .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) - .get(PreKeyResponse.class); + .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) + .queryParam("pq", "true") + .request() + .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) + .get(PreKeyResponse.class); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getDevicesCount()).isEqualTo(1); - assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); - assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); - assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); + assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey()); + assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); + assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey()); - verify(KEYS).take(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).takePQ(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -422,59 +497,118 @@ class KeysControllerTest { @Test void validMultiRequestTestV2() { - when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.take(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); - when(KEYS.take(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); - when(KEYS.take(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); + when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); + when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); + when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); PreKeyResponse results = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) - .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(PreKeyResponse.class); + .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); - PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); - PreKey preKey = results.getDevice(1).getPreKey(); - long registrationId = results.getDevice(1).getRegistrationId(); - long deviceId = results.getDevice(1).getDeviceId(); + PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); + PreKey preKey = results.getDevice(1).getPreKey(); + long registrationId = results.getDevice(1).getRegistrationId(); + long deviceId = results.getDevice(1).getDeviceId(); - assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); - assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); + assertKeysMatch(SAMPLE_KEY, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); - assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY.getKeyId()); - assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY.getPublicKey()); + assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); assertThat(deviceId).isEqualTo(1); - signedPreKey = results.getDevice(2).getSignedPreKey(); - preKey = results.getDevice(2).getPreKey(); + signedPreKey = results.getDevice(2).getSignedPreKey(); + preKey = results.getDevice(2).getPreKey(); registrationId = results.getDevice(2).getRegistrationId(); - deviceId = results.getDevice(2).getDeviceId(); + deviceId = results.getDevice(2).getDeviceId(); - assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId()); - assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey()); + assertKeysMatch(SAMPLE_KEY2, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); - assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY2.getKeyId()); - assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY2.getPublicKey()); + assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); assertThat(deviceId).isEqualTo(2); - signedPreKey = results.getDevice(4).getSignedPreKey(); - preKey = results.getDevice(4).getPreKey(); + signedPreKey = results.getDevice(4).getSignedPreKey(); + preKey = results.getDevice(4).getPreKey(); registrationId = results.getDevice(4).getRegistrationId(); - deviceId = results.getDevice(4).getDeviceId(); + deviceId = results.getDevice(4).getDeviceId(); - assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId()); - assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey()); + assertKeysMatch(SAMPLE_KEY4, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(signedPreKey).isNull(); assertThat(deviceId).isEqualTo(4); - verify(KEYS).take(EXISTS_UUID, 1); - verify(KEYS).take(EXISTS_UUID, 2); - verify(KEYS).take(EXISTS_UUID, 3); - verify(KEYS).take(EXISTS_UUID, 4); + verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, 2); + verify(KEYS).takeEC(EXISTS_UUID, 4); + verifyNoMoreInteractions(KEYS); + } + + @Test + void validMultiRequestPqTestV2() { + when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); + when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); + when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); + when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY)); + when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2)); + when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3)); + when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.empty()); + + PreKeyResponse results = resources.getJerseyTest() + .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) + .queryParam("pq", "true") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get(PreKeyResponse.class); + + assertThat(results.getDevicesCount()).isEqualTo(3); + assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); + + PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); + PreKey preKey = results.getDevice(1).getPreKey(); + SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey(); + long registrationId = results.getDevice(1).getRegistrationId(); + long deviceId = results.getDevice(1).getDeviceId(); + + assertKeysMatch(SAMPLE_KEY, preKey); + assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey); + assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); + assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey); + assertThat(deviceId).isEqualTo(1); + + signedPreKey = results.getDevice(2).getSignedPreKey(); + preKey = results.getDevice(2).getPreKey(); + pqPreKey = results.getDevice(2).getPqPreKey(); + registrationId = results.getDevice(2).getRegistrationId(); + deviceId = results.getDevice(2).getDeviceId(); + + assertThat(preKey).isNull(); + assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey); + assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); + assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey); + assertThat(deviceId).isEqualTo(2); + + signedPreKey = results.getDevice(4).getSignedPreKey(); + preKey = results.getDevice(4).getPreKey(); + pqPreKey = results.getDevice(4).getPqPreKey(); + registrationId = results.getDevice(4).getRegistrationId(); + deviceId = results.getDevice(4).getDeviceId(); + + assertKeysMatch(SAMPLE_KEY4, preKey); + assertThat(pqPreKey).isNull(); + assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); + assertThat(signedPreKey).isNull(); + assertThat(deviceId).isEqualTo(4); + + verify(KEYS).takeEC(EXISTS_UUID, 1); + verify(KEYS).takePQ(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, 2); + verify(KEYS).takePQ(EXISTS_UUID, 2); + verify(KEYS).takeEC(EXISTS_UUID, 4); + verify(KEYS).takePQ(EXISTS_UUID, 4); verifyNoMoreInteractions(KEYS); } @@ -523,16 +657,12 @@ class KeysControllerTest { @Test void putKeysTestV2() { - final PreKey preKey = new PreKey(31337, "foobar"); + final PreKey preKey = new PreKey(31337, "foobar"); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); - final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); + final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); - List preKeys = new LinkedList() {{ - add(preKey); - }}; - - PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys); + PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); Response response = resources.getJerseyTest() @@ -544,12 +674,41 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture()); + verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); - List capturedList = listCaptor.getValue(); - assertThat(capturedList.size()).isEqualTo(1); - assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); - assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar"); + assertThat(listCaptor.getValue()).containsExactly(preKey); + + verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey)); + verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); + verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + } + + @Test + void putKeysPqTestV2() { + final PreKey preKey = new PreKey(31337, "foobar"); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); + final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair); + final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair); + final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); + + PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); + + Response response = + resources.getJerseyTest() + .target("/v2/keys") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(204); + + ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); + verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); + + assertThat(ecCaptor.getValue()).containsExactly(preKey); + assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); @@ -558,13 +717,12 @@ class KeysControllerTest { @Test void putKeysByPhoneNumberIdentifierTestV2() { + final PreKey preKey = new PreKey(31337, "foobar"); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); - final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); + final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); - List preKeys = List.of(new PreKey(31337, "foobar")); - - PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys); + PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey)); Response response = resources.getJerseyTest() @@ -577,12 +735,42 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture()); + verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull()); - List capturedList = listCaptor.getValue(); - assertThat(capturedList.size()).isEqualTo(1); - assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); - assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar"); + assertThat(listCaptor.getValue()).containsExactly(preKey); + + verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey)); + verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); + verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + } + + @Test + void putKeysByPhoneNumberIdentifierPqTestV2() { + final PreKey preKey = new PreKey(31337, "foobar"); + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); + final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair); + final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair); + final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); + + PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey); + + Response response = + resources.getJerseyTest() + .target("/v2/keys") + .queryParam("identity", "pni") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); + + assertThat(response.getStatus()).isEqualTo(204); + + ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); + verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey)); + + assertThat(ecCaptor.getValue()).containsExactly(preKey); + assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); @@ -627,7 +815,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture()); + verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -657,4 +845,13 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(403); } + + private void assertKeysMatch(PreKey expected, PreKey actual) { + assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId()); + assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey()); + if (expected instanceof final SignedPreKey signedExpected) { + final SignedPreKey signedActual = (SignedPreKey) actual; + assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature()); + } + } }