PQXDH endpoints for chat server

This commit is contained in:
Jonathan Klabunde Tomer 2023-05-16 17:34:33 -04:00 committed by GitHub
parent 34d77e73ff
commit caae27c44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1378 additions and 380 deletions

View File

@ -53,8 +53,12 @@ dynamoDbTables:
tableName: Example_IssuedReceipts
expiration: P30D # Duration of time until rows expire
generator: abcdefg12345678= # random base64-encoded binary sequence
keys:
ecKeys:
tableName: Example_Keys
pqKeys:
tableName: Example_PQ_Keys
pqLastResortKeys:
tableName: Example_PQ_Last_Resort_Keys
messages:
tableName: Example_Messages
expiration: P30D # Duration of time until rows expire

View File

@ -341,7 +341,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, config.getDynamoDbTables().getKeys().getTableName());
Keys keys = new Keys(dynamoDbClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getPqKeys().getTableName(),
config.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -50,7 +50,9 @@ public class DynamoDbTables {
private final Table deletedAccounts;
private final Table deletedAccountsLock;
private final IssuedReceiptsTableConfiguration issuedReceipts;
private final Table keys;
private final Table ecKeys;
private final Table pqKeys;
private final Table pqLastResortKeys;
private final TableWithExpiration messages;
private final Table pendingAccounts;
private final Table pendingDevices;
@ -69,7 +71,9 @@ public class DynamoDbTables {
@JsonProperty("deletedAccounts") final Table deletedAccounts,
@JsonProperty("deletedAccountsLock") final Table deletedAccountsLock,
@JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts,
@JsonProperty("keys") final Table keys,
@JsonProperty("ecKeys") final Table ecKeys,
@JsonProperty("pqKeys") final Table pqKeys,
@JsonProperty("pqLastResortKeys") final Table pqLastResortKeys,
@JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("pendingAccounts") final Table pendingAccounts,
@JsonProperty("pendingDevices") final Table pendingDevices,
@ -87,7 +91,9 @@ public class DynamoDbTables {
this.deletedAccounts = deletedAccounts;
this.deletedAccountsLock = deletedAccountsLock;
this.issuedReceipts = issuedReceipts;
this.keys = keys;
this.ecKeys = ecKeys;
this.pqKeys = pqKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.messages = messages;
this.pendingAccounts = pendingAccounts;
this.pendingDevices = pendingDevices;
@ -128,8 +134,20 @@ public class DynamoDbTables {
@NotNull
@Valid
public Table getKeys() {
return keys;
public Table getEcKeys() {
return ecKeys;
}
@NotNull
@Valid
public Table getPqKeys() {
return pqKeys;
}
@NotNull
@Valid
public Table getPqLastResortKeys() {
return pqLastResortKeys;
}
@NotNull

View File

@ -493,6 +493,7 @@ public class AccountController {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@ -128,6 +128,7 @@ public class AccountControllerV2 {
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());
@ -172,10 +173,11 @@ public class AccountControllerV2 {
}
try {
final Account updatedAccount = changeNumberManager.updatePNIKeys(
final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedAccount.getAccount(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());

View File

@ -11,14 +11,21 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@ -75,12 +82,14 @@ public class KeysController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Returns the number of available one-time prekeys for this device")
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {
int count = keys.getCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(count);
return new PreKeyCount(ecCount, pqCount);
}
@Timed
@ -88,9 +97,17 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Sets the identity key for the account or phone-number identity and/or prekeys for this device")
public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@NotNull @Valid final PreKeyState preKeys,
@RequestBody @NotNull @Valid final PreKeyState preKeys,
@Parameter(allowEmptyValue=true)
@Schema(
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
@ -98,7 +115,8 @@ public class KeysController {
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType);
if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
if (preKeys.getSignedPreKey() != null &&
!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
updateAccount = true;
}
@ -121,13 +139,15 @@ public class KeysController {
if (updateAccount) {
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
if (preKeys.getSignedPreKey() != null) {
a.getDevice(device.getId()).ifPresent(d -> {
if (usePhoneNumberIdentity) {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
} else {
d.setSignedPreKey(preKeys.getSignedPreKey());
}
});
}
if (usePhoneNumberIdentity) {
a.setPhoneNumberIdentityKey(preKeys.getIdentityKey());
@ -137,17 +157,29 @@ public class KeysController {
});
}
keys.store(getIdentifier(account, identityType), device.getId(), preKeys.getPreKeys());
keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getPqLastResortPreKey());
}
@Timed
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Retrieves the public identity key and available device prekeys for a specified account or phone-number identity")
public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") UUID targetUuid,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId,
@Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys")
@Schema(defaultValue="false")
@QueryParam("pq") boolean returnPqKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
@ -175,28 +207,30 @@ public class KeysController {
final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid);
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target, deviceId, usePhoneNumberIdentity);
List<PreKeyResponseItem> responseItems = new LinkedList<>();
List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
for (Device device : target.getDevices()) {
if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey preKey = preKeysByDeviceId.get(device.getId());
for (Device device : devices) {
UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
if (signedPreKey != null || preKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey));
}
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
}
}
final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey();
if (responseItems.isEmpty()) return Response.status(404).build();
else return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
if (responseItems.isEmpty()) {
return Response.status(404).build();
}
return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
}
@Timed
@ -243,31 +277,15 @@ public class KeysController {
account.getUuid();
}
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector, final boolean usePhoneNumberIdentity) {
final Map<Long, PreKey> preKeys;
final UUID identifier = usePhoneNumberIdentity ?
destination.getPhoneNumberIdentifier() :
destination.getUuid();
if (deviceIdSelector.equals("*")) {
preKeys = new HashMap<>();
for (final Device device : destination.getDevices()) {
keys.take(identifier, device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey));
}
} else {
try {
long deviceId = Long.parseLong(deviceIdSelector);
preKeys = keys.take(identifier, deviceId)
.map(preKey -> Map.of(deviceId, preKey))
.orElse(Collections.emptyMap());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();
}
try {
long id = Long.parseLong(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
return preKeys;
}
}

View File

@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
@ -16,21 +18,57 @@ import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record ChangeNumberRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotBlank String number,
@JsonProperty("reglock") @Nullable String registrationLock,
@NotBlank String pniIdentityKey,
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
public record ChangeNumberRequest(
@Schema(description="""
A session ID from registration service, if using session id to authenticate this request.
Must not be combined with `recoveryPassword`.""")
String sessionId,
@Schema(description="""
The recovery password for the new phone number, if using a recovery password to authenticate this request.
Must not be combined with `sessionId`.""")
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@Schema(description="the new phone number for this account")
@NotBlank String number,
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@NotBlank String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) {
return true;
List<SignedPreKey> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@ -6,27 +6,61 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.validation.constraints.AssertTrue;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
public record ChangePhoneNumberRequest(@NotBlank String number,
@NotBlank String code,
@JsonProperty("reglock") @Nullable String registrationLock,
@Nullable String pniIdentityKey,
@Nullable List<IncomingMessage> deviceMessages,
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Nullable Map<Long, Integer> pniRegistrationIds) {
public record ChangePhoneNumberRequest(
@Schema(description="the new phone number for this account")
@NotBlank String number,
@Schema(description="the registration verification code to authenticate this request")
@NotBlank String code,
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@Nullable String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@Nullable List<IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@Nullable Map<Long, Integer> pniRegistrationIds) {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) {
return true;
List<SignedPreKey> spks = new ArrayList<>();
if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
}
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
@ -17,29 +18,45 @@ import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record PhoneNumberIdentityKeyDistributionRequest(
@NotBlank
@Schema(description="the new identity key for this account's phone-number identity")
String pniIdentityKey,
@NotBlank
@Schema(description="the new identity key for this account's phone-number identity")
String pniIdentityKey,
@NotNull
@Valid
@Schema(description="A message for each companion device to pass its new private keys")
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@Valid
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@Valid
@Schema(description="The public key of a new signed elliptic-curve prekey pair for each device")
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull
@Valid
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device")
Map<Long, Integer> pniRegistrationIds) {
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device")
Map<Long, Integer> pniRegistrationIds) {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {
return devicePniSignedPrekeys.values().parallelStream()
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk));
List<SignedPreKey> spks = new ArrayList<>(devicePniSignedPrekeys.values());
if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
}
}

View File

@ -13,17 +13,17 @@ public class PreKey {
@JsonProperty
@NotNull
private long keyId;
private long keyId;
@JsonProperty
@NotEmpty
private String publicKey;
private String publicKey;
public PreKey() {}
public PreKey(long keyId, String publicKey)
{
this.keyId = keyId;
this.keyId = keyId;
this.publicKey = publicKey;
}
@ -63,5 +63,4 @@ public class PreKey {
return ((int)this.keyId) ^ publicKey.hashCode();
}
}
}

View File

@ -5,16 +5,22 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyCount {
@Schema(description="the number of stored unsigned elliptic-curve prekeys for this device")
@JsonProperty
private int count;
public PreKeyCount(int count) {
this.count = count;
@Schema(description="the number of stored one-time post-quantum prekeys for this device")
@JsonProperty
private int pqCount;
public PreKeyCount(int ecCount, int pqCount) {
this.count = ecCount;
this.pqCount = pqCount;
}
public PreKeyCount() {}
@ -22,4 +28,8 @@ public class PreKeyCount {
public int getCount() {
return count;
}
public int getPqCount() {
return pqCount;
}
}

View File

@ -7,15 +7,18 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
public class PreKeyResponse {
@JsonProperty
@Schema(description="the public identity key for the requested identity")
private String identityKey;
@JsonProperty
@Schema(description="information about each requested device")
private List<PreKeyResponseItem> devices;
public PreKeyResponse() {}

View File

@ -6,28 +6,39 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyResponseItem {
@JsonProperty
@Schema(description="the device ID of the device to which this item pertains")
private long deviceId;
@JsonProperty
@Schema(description="the registration ID for the device")
private int registrationId;
@JsonProperty
@Schema(description="the signed elliptic-curve prekey for the device, if one has been set")
private SignedPreKey signedPreKey;
@JsonProperty
@Schema(description="an unsigned elliptic-curve prekey for the device, if any remain")
private PreKey preKey;
@JsonProperty
@Schema(description="a signed post-quantum prekey for the device " +
"(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)")
private SignedPreKey pqPreKey;
public PreKeyResponseItem() {}
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey) {
this.deviceId = deviceId;
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) {
this.deviceId = deviceId;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
this.preKey = preKey;
this.signedPreKey = signedPreKey;
this.preKey = preKey;
this.pqPreKey = pqPreKey;
}
@VisibleForTesting
@ -40,6 +51,11 @@ public class PreKeyResponseItem {
return preKey;
}
@VisibleForTesting
public SignedPreKey getPqPreKey() {
return pqPreKey;
}
@VisibleForTesting
public int getRegistrationId() {
return registrationId;

View File

@ -5,24 +5,38 @@
package org.whispersystems.textsecuregcm.entities;
import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.util.Base64;
import java.util.Collection;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
public abstract class PreKeySignatureValidator {
public static final boolean validatePreKeySignature(final String identityKeyB64, final SignedPreKey spk) {
public static final Counter INVALID_SIGNATURE_COUNTER =
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature"));
public static final boolean validatePreKeySignatures(final String identityKeyB64, final Collection<SignedPreKey> spks) {
try {
final byte[] identityKeyBytes = Base64.getDecoder().decode(identityKeyB64);
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0);
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes);
final boolean success = spks.stream().allMatch(spk -> {
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes);
});
if (!success) {
INVALID_SIGNATURE_COUNTER.increment();
}
return success;
} catch (IllegalArgumentException | InvalidKeyException e) {
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")).increment();
INVALID_SIGNATURE_COUNTER.increment();
return false;
}
}

View File

@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
@ -15,26 +17,59 @@ import javax.validation.constraints.NotNull;
public class PreKeyState {
@JsonProperty
@NotNull
@Valid
@Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " +
"If present and not empty, replaces all stored unsigned EC prekeys for the device; " +
"if absent or empty, any stored unsigned EC prekeys for the device are not deleted.")
private List<PreKey> preKeys;
@JsonProperty
@NotNull
@Valid
@Schema(description="An optional signed elliptic-curve prekey to use for this device. " +
"If present, replaces the stored signed elliptic-curve prekey for the device; " +
"if absent, the stored signed prekey is not deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey signedPreKey;
@JsonProperty
@Valid
@Schema(description="A list of signed post-quantum one-time prekeys to use for this device. " +
"Each key must have a valid signature from the identity key in this request. " +
"If present and not empty, replaces all stored unsigned PQ prekeys for the device; " +
"if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.")
private List<SignedPreKey> pqPreKeys;
@JsonProperty
@Valid
@Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " +
"If present, replaces the stored signed post-quantum last-resort prekey for the device; " +
"if absent, a stored last-resort prekey will *not* be deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey pqLastResortPreKey;
@JsonProperty
@NotEmpty
@NotNull
@Schema(description="Required. " +
"The public identity key for this identity (account or phone-number identity). " +
"If this device is not the primary device for the account, " +
"must match the existing stored identity key for this identity.")
private String identityKey;
public PreKeyState() {}
@VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys) {
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this(identityKey, signedPreKey, keys, null, null);
}
@VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) {
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this.pqPreKeys = pqKeys;
this.pqLastResortPreKey = pqLastResortKey;
}
public List<PreKey> getPreKeys() {
@ -45,12 +80,30 @@ public class PreKeyState {
return signedPreKey;
}
public List<SignedPreKey> getPqPreKeys() {
return pqPreKeys;
}
public SignedPreKey getPqLastResortPreKey() {
return pqLastResortPreKey;
}
public String getIdentityKey() {
return identityKey;
}
@AssertTrue
public boolean isSignatureValid() {
return PreKeySignatureValidator.validatePreKeySignature(identityKey, signedPreKey);
public boolean isSignatureValidOnEachSignedKey() {
List<SignedPreKey> spks = new ArrayList<>();
if (pqPreKeys != null) {
spks.addAll(pqPreKeys);
}
if (pqLastResortPreKey != null) {
spks.add(pqLastResortPreKey);
}
if (signedPreKey != null) {
spks.add(signedPreKey);
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks);
}
}

View File

@ -45,5 +45,4 @@ public class SignedPreKey extends PreKey {
return super.hashCode() ^ signature.hashCode();
}
}
}

View File

@ -12,6 +12,7 @@ import static io.micrometer.core.instrument.Metrics.timer;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@ -53,7 +54,7 @@ public abstract class AbstractDynamoDbStore {
return dynamoDbClient;
}
protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) {
protected void executeTableWriteItemsUntilComplete(final Map<String, ? extends Collection<WriteRequest>> items) {
final AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome);
int attemptCount = 0;
@ -80,7 +81,7 @@ public abstract class AbstractDynamoDbStore {
}
private void writeAndStoreOutcome(
final Map<String, List<WriteRequest>> items,
final Map<String, ? extends Collection<WriteRequest>> items,
final Timer timer,
final AtomicReference<BatchWriteItemResponse> outcome) {
timer.record(

View File

@ -245,6 +245,7 @@ public class AccountsManager {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber();
@ -252,12 +253,12 @@ public class AccountsManager {
if (originalNumber.equals(number)) {
if (pniIdentityKey != null) {
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePNIKeys");
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
}
return account;
}
validateDevices(account, pniSignedPreKeys, pniRegistrationIds);
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
@ -281,7 +282,7 @@ public class AccountsManager {
numberChangedAccount = updateWithRetries(
account,
a -> setPNIKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds),
a -> { setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); return true; },
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
@ -291,45 +292,74 @@ public class AccountsManager {
keys.delete(phoneNumberIdentifier);
keys.delete(originalPhoneNumberIdentifier);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(
phoneNumberIdentifier,
keys.getPqEnabledDevices(uuid).stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
}
return displacedUuid;
});
return updatedAccount.get();
}
public Account updatePNIKeys(final Account account,
public Account updatePniKeys(final Account account,
final String pniIdentityKey,
final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniRegistrationIds);
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
return update(account, a -> { return setPNIKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni);
keys.delete(pni);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
}
return updatedAccount;
}
private boolean setPNIKeys(final Account account,
private boolean setPniKeys(final Account account,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return true;
return false;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
}
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
boolean changed = !pniIdentityKey.equals(account.getPhoneNumberIdentityKey());
for (Device device : account.getDevices()) {
if (!device.isEnabled()) {
continue;
}
SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId());
changed = changed ||
!signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) ||
device.getRegistrationId() != registrationId;
device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
device.setPhoneNumberIdentityRegistrationId(registrationId);
}
pniRegistrationIds.forEach((deviceId, registrationId) ->
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
return true;
return changed;
}
private void validateDevices(final Account account,
final Map<Long, SignedPreKey> pniSignedPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
@ -342,6 +372,12 @@ public class AccountsManager {
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all including master ID are in Pq pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList(
account,

View File

@ -42,6 +42,7 @@ public class ChangeNumberManager {
public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
@ -62,10 +63,14 @@ public class ChangeNumberManager {
// We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key
// material and distribute the sync messages, to be sure all clients agree with us and each other about what their
// keys are. Pretend this change-number request was actually a PNI key distribution request.
return updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds);
if (pniIdentityKey == null) {
return account;
}
return updatePniKeys(account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, deviceMessages, pniRegistrationIds);
}
final Account updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
final Account updatedAccount = accountsManager.changeNumber(
account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
if (deviceMessages != null) {
sendDeviceMessages(updatedAccount, deviceMessages);
@ -74,16 +79,18 @@ public class ChangeNumberManager {
return updatedAccount;
}
public Account updatePNIKeys(final Account account,
public Account updatePniKeys(final Account account,
final String pniIdentityKey,
final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
validateDeviceMessages(account, deviceMessages);
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations.
final Account updatedAccount = accountsManager.updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds);
final Account updatedAccount = accountsManager.updatePniKeys(
account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
sendDeviceMessages(updatedAccount, deviceMessages);
return updatedAccount;

View File

@ -6,6 +6,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
@ -16,7 +19,11 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -34,11 +41,14 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class Keys extends AbstractDynamoDbStore {
private final String tableName;
private final String ecTableName;
private final String pqTableName;
private final String pqLastResortTableName;
static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String KEY_PUBLIC_KEY = "P";
static final String KEY_SIGNATURE = "S";
private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys"));
private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice"));
@ -48,31 +58,114 @@ public class Keys extends AbstractDynamoDbStore {
private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys"));
private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount"));
private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty"));
private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys"));
public Keys(final DynamoDbClient dynamoDB, final String tableName) {
public Keys(
final DynamoDbClient dynamoDB,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
super(dynamoDB);
this.tableName = tableName;
this.ecTableName = ecTableName;
this.pqTableName = pqTableName;
this.pqLastResortTableName = pqLastResortTableName;
}
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
STORE_KEYS_TIMER.record(() -> {
delete(identifier, deviceId);
store(identifier, deviceId, keys, null, null);
}
writeInBatches(keys, batch -> {
List<WriteRequest> items = new ArrayList<>();
for (final PreKey preKey : batch) {
items.add(WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.build());
}
executeTableWriteItemsUntilComplete(Map.of(tableName, items));
public void store(
final UUID identifier, final long deviceId,
@Nullable final List<PreKey> ecKeys,
@Nullable final List<SignedPreKey> pqKeys,
@Nullable final SignedPreKey pqLastResortKey) {
Multimap<String, PreKey> keys = MultimapBuilder.hashKeys().arrayListValues().build();
List<String> tablesToClear = new ArrayList<>();
if (ecKeys != null && !ecKeys.isEmpty()) {
keys.putAll(ecTableName, ecKeys);
tablesToClear.add(ecTableName);
}
if (pqKeys != null && !pqKeys.isEmpty()) {
keys.putAll(pqTableName, pqKeys);
tablesToClear.add(pqTableName);
}
if (pqLastResortKey != null) {
keys.put(pqLastResortTableName, pqLastResortKey);
tablesToClear.add(pqLastResortTableName);
}
STORE_KEYS_TIMER.record(() -> {
delete(tablesToClear, identifier, deviceId);
writeInBatches(
keys.entries(),
batch -> {
Multimap<String, WriteRequest> writes = batch.stream()
.collect(
Multimaps.toMultimap(
Map.Entry<String, PreKey>::getKey,
entry -> WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, entry.getValue()))
.build())
.build(),
MultimapBuilder.hashKeys().arrayListValues()::build));
executeTableWriteItemsUntilComplete(writes.asMap());
});
});
}
public Optional<PreKey> take(final UUID identifier, final long deviceId) {
public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
final List<WriteRequest> writes = new ArrayList<>(2 * keys.size());
final Map<Long, Map<String, AttributeValue>> newItems = keys.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue())));
for (final Map<String, AttributeValue> item : db().query(queryRequest).items()) {
final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID);
final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong();
if (newItems.containsKey(oldDeviceId)) {
final Map<String, AttributeValue> replacement = newItems.get(oldDeviceId);
if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) {
writes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, oldSortKey))
.build())
.build());
}
}
}
newItems.forEach((unusedKey, item) ->
writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build()));
executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes));
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return take(ecTableName, identifier, deviceId);
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return take(pqTableName, identifier, deviceId)
.or(() -> getLastResort(identifier, deviceId))
.map(pk -> (SignedPreKey) pk);
}
private Optional<PreKey> take(final String tableName, final UUID identifier, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
@ -114,7 +207,53 @@ public class Keys extends AbstractDynamoDbStore {
});
}
public int getCount(final UUID identifier, final long deviceId) {
@VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.consistentRead(false)
.select(Select.ALL_ATTRIBUTES)
.build();
QueryResponse response = db().query(queryRequest);
if (response.count() > 1) {
TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment();
}
return response.items().stream().findFirst().map(this::getPreKeyFromItem);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
final QueryResponse response = db().query(queryRequest);
return response.items().stream()
.map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong())
.toList();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return getCount(ecTableName, identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return getCount(pqTableName, identifier, deviceId);
}
private int getCount(final String tableName, final UUID identifier, final long deviceId) {
return GET_KEY_COUNT_TIMER.record(() -> {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
@ -144,51 +283,66 @@ public class Keys extends AbstractDynamoDbStore {
public void delete(final UUID accountUuid) {
DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid)))
":uuid", getPartitionKey(accountUuid)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest);
deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest);
});
}
public void delete(final UUID accountUuid, final long deviceId) {
delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId);
}
private void delete(final List<String> tableNames, final UUID accountUuid, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest);
deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest);
});
}
private void deleteItemsForAccountMatchingQuery(final UUID accountUuid, final QueryRequest querySpec) {
private void deleteItemsForAccountMatchingQuery(final List<String> tableNames, final UUID accountUuid, final QueryRequest querySpec) {
final AttributeValue partitionKey = getPartitionKey(accountUuid);
writeInBatches(db().query(querySpec).items(), batch -> {
List<WriteRequest> deletes = new ArrayList<>();
for (final Map<String, AttributeValue> item : batch) {
deletes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build());
}
executeTableWriteItemsUntilComplete(Map.of(tableName, deletes));
Multimap<String, Map<String, AttributeValue>> itemStream = tableNames.stream()
.collect(
Multimaps.flatteningToMultimap(
Function.identity(),
tableName ->
db().query(querySpec.toBuilder().tableName(tableName).build())
.items()
.stream(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
writeInBatches(
itemStream.entries(),
batch -> {
Multimap<String, WriteRequest> deletes = batch.stream()
.collect(Multimaps.toMultimap(
Map.Entry<String, Map<String, AttributeValue>>::getKey,
entry -> WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
executeTableWriteItemsUntilComplete(deletes.asMap());
});
}
@ -211,6 +365,13 @@ public class Keys extends AbstractDynamoDbStore {
}
private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) {
if (preKey instanceof final SignedPreKey spk) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromString(spk.getPublicKey()),
KEY_SIGNATURE, AttributeValues.fromString(spk.getSignature()));
}
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
@ -219,6 +380,11 @@ public class Keys extends AbstractDynamoDbStore {
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s());
}
return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s());
}
}

View File

@ -174,7 +174,9 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -154,7 +154,9 @@ record CommandDependencies(
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName());
configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -334,7 +334,7 @@ class AccountControllerTest {
return account;
});
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final String pniIdentityKey = invocation.getArgument(2, String.class);
@ -358,7 +358,7 @@ class AccountControllerTest {
return updatedAccount;
});
when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String pniIdentityKey = invocation.getArgument(1, String.class);
@ -1377,12 +1377,12 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any());
assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.number()).isEqualTo(number);
@ -1399,12 +1399,12 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.readEntity(String.class)).isBlank();
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1417,7 +1417,7 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400);
@ -1426,7 +1426,7 @@ class AccountControllerTest {
assertThat(responseEntity.getOriginalNumber()).isEqualTo(number);
assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111");
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1436,10 +1436,10 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
}
@Test
@ -1454,11 +1454,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1478,13 +1478,13 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT);
assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1514,11 +1514,11 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
}
@Test
@ -1549,14 +1549,14 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
// verify(existingAccount).lockAuthenticationCredentials();
// verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1589,14 +1589,14 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
// verify(existingAccount).lockAuthenticationCredentials();
// verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
}
@Test
@ -1628,13 +1628,13 @@ class AccountControllerTest {
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null),
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
verify(senderRegLockAccount, never()).lockAuthTokenHash();
verify(clientPresenceManager, never()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
}
@Test
@ -1681,10 +1681,11 @@ class AccountControllerTest {
number, code, null,
pniIdentityKey, deviceMessages,
deviceKeys,
null,
registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any());
assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.number()).isEqualTo(number);
@ -1734,11 +1735,12 @@ class AccountControllerTest {
AuthHelper.VALID_NUMBER, code, null,
pniIdentityKey, deviceMessages,
deviceKeys,
null,
registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(
eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any());
eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any(), any());
verifyNoInteractions(rateLimiter);
verifyNoInteractions(pendingAccountsManager);

View File

@ -134,7 +134,7 @@ class AccountControllerV2Test {
void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer(
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
@ -180,11 +180,11 @@ class AccountControllerV2Test {
.put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123",
Collections.emptyList(),
Collections.emptyMap(), Collections.emptyMap()),
Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any());
any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
@ -203,11 +203,11 @@ class AccountControllerV2Test {
new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null,
"pni-identity-key",
Collections.emptyList(),
Collections.emptyMap(), Collections.emptyMap()),
Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(),
any());
any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number());
@ -365,7 +365,7 @@ class AccountControllerV2Test {
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any());
any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
@ -458,7 +458,7 @@ class AccountControllerV2Test {
@BeforeEach
void setUp() throws Exception {
when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer(
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String pniIdentityKey = invocation.getArgument(1, String.class);
@ -496,7 +496,7 @@ class AccountControllerV2Test {
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(requestJson()), AccountIdentityResponse.class);
verify(changeNumberManager).updatePNIKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any());
verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number());
@ -557,6 +557,7 @@ class AccountControllerV2Test {
"pniIdentityKey": "pni-identity-key",
"deviceMessages": [],
"devicePniSignedPrekeys": {},
"devicePniSignedPqPrekeys": {},
"pniRegistrationIds": {}
}
""";

View File

@ -128,7 +128,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account, secondNumber, null, null, null);
accountsManager.changeNumber(account, secondNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -161,7 +161,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, registrationIds);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -191,8 +191,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
account = accountsManager.changeNumber(account, secondNumber, null, null, null);
accountsManager.changeNumber(account, originalNumber, null, null, null);
account = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
accountsManager.changeNumber(account, originalNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isPresent());
assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow());
@ -217,7 +217,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber, null, null, null);
accountsManager.changeNumber(account, secondNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -231,7 +231,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null);
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null);
final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(),
new ArrayList<>());
@ -251,7 +251,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid();
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null);
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
@ -262,7 +262,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null);
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null, null);
assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));

View File

@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
@ -641,7 +642,7 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber, null, null, null);
account = accountsManager.changeNumber(account, targetNumber, null, null, null, null);
assertEquals(targetNumber, account.getNumber());
@ -656,7 +657,7 @@ class AccountsManagerTest {
final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, number, null, null, null);
account = accountsManager.changeNumber(account, number, null, null, null, null);
assertEquals(number, account.getNumber());
verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any());
@ -664,13 +665,13 @@ class AccountsManagerTest {
}
@Test
void testChangePhoneNumberSameNumberWithPNIData() {
void testChangePhoneNumberSameNumberWithPniData() {
final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber(
account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), Map.of(1L, 101)),
account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any());
@ -694,14 +695,60 @@ class AccountsManagerTest {
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber, null, null, null);
account = accountsManager.changeNumber(account, targetNumber, null, null, null, null);
assertEquals(targetNumber, account.getNumber());
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys).delete(originalPni);
verify(keys).delete(targetPni);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verifyNoMoreInteractions(keys);
}
@Test
void testChangePhoneNumberWithPqKeysExistingAccount() throws InterruptedException, MismatchedDevicesException {
doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty()))
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333";
final UUID existingAccountUuid = UUID.randomUUID();
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Map<Long, SignedPreKey> newSignedKeys = Map.of(
1L, new SignedPreKey(1L, "pub1", "sig1"),
2L, new SignedPreKey(2L, "pub2", "sig2"));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
1L, new SignedPreKey(3L, "pub3", "sig3"),
2L, new SignedPreKey(4L, "pub4", "sig4"));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds);
assertEquals(targetNumber, updatedAccount.getNumber());
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verify(keys).delete(originalPni);
verify(keys).getPqEnabledDevices(uuid);
verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, new SignedPreKey(3L, "pub3", "sig3"))));
verifyNoMoreInteractions(keys);
}
@Test
@ -716,7 +763,7 @@ class AccountsManagerTest {
}
@Test
void testPNIUpdate() throws MismatchedDevicesException {
void testPniUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
@ -730,7 +777,7 @@ class AccountsManagerTest {
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final Account updatedAccount = accountsManager.updatePNIKeys(account, "new-pni-identity-key", newSignedKeys, newRegistrationIds);
final Account updatedAccount = accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, null, newRegistrationIds);
// non-PNI stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid());
@ -750,7 +797,57 @@ class AccountsManagerTest {
verify(accounts).update(any());
verifyNoInteractions(deletedAccountsManager);
verifyNoInteractions(keys);
verify(keys).delete(oldPni);
}
@Test
void testPniPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
Map<Long, SignedPreKey> newSignedKeys = Map.of(
1L, new SignedPreKey(1L, "pub1", "sig1"),
2L, new SignedPreKey(2L, "pub2", "sig2"));
Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
1L, new SignedPreKey(3L, "pub3", "sig3"),
2L, new SignedPreKey(4L, "pub4", "sig4"));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final Account updatedAccount =
accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds);
// non-PNI-keys stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid());
assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertEquals(null, updatedAccount.getIdentityKey());
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)));
assertEquals(Map.of(1L, 101, 2L, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
assertEquals("new-pni-identity-key", updatedAccount.getPhoneNumberIdentityKey());
assertEquals(newSignedKeys,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey)));
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).update(any());
verifyNoInteractions(deletedAccountsManager);
verify(keys).delete(oldPni);
// only the pq key for the already-pq-enabled device should be saved
verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
}
@Test

View File

@ -47,7 +47,7 @@ public class ChangeNumberManagerTest {
updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
when(accountsManager.changeNumber(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
when(accountsManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
@ -70,7 +70,7 @@ public class ChangeNumberManagerTest {
return updatedAccount;
});
when(accountsManager.updatePNIKeys(any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
when(accountsManager.updatePniKeys(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final UUID uuid = account.getUuid();
@ -94,8 +94,8 @@ public class ChangeNumberManagerTest {
void changeNumberNoMessages() throws Exception {
Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null);
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@ -107,8 +107,8 @@ public class ChangeNumberManagerTest {
var prekeys = Map.of(1L, new SignedPreKey());
final String pniIdentityKey = "pni-identity-key";
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyMap());
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@ -139,9 +139,53 @@ public class ChangeNumberManagerTest {
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, List.of(msg), registrationIds);
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(Device.MASTER_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234";
final String changedE164 = "+18025551234";
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn(originalE164);
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -174,15 +218,16 @@ public class ChangeNumberManagerTest {
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, List.of(msg), registrationIds);
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds);
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -196,7 +241,7 @@ public class ChangeNumberManagerTest {
}
@Test
void updatePNIKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception {
void updatePniKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
@ -219,9 +264,49 @@ public class ChangeNumberManagerTest {
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePNIKeys(account, pniIdentityKey, prekeys, List.of(msg), registrationIds);
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds);
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(Device.MASTER_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@Test
void updatePniKeysSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -261,11 +346,11 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds));
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, null, messages, registrationIds));
}
@Test
void updatePNIKeysMismatchedRegistrationId() {
void updatePniKeysMismatchedRegistrationId() {
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234");
@ -291,7 +376,7 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePNIKeys(account, "pni-identity-key", preKeys, messages, registrationIds));
() -> changeNumberManager.updatePniKeys(account, "pni-identity-key", preKeys, null, messages, registrationIds));
}
@Test
@ -320,6 +405,6 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, messages, registrationIds));
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, null, messages, registrationIds));
}
}

View File

@ -69,7 +69,35 @@ public final class DynamoDbExtensionSchema {
.build()),
List.of(), List.of()),
KEYS("keys_test",
EC_KEYS("keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_KEYS("pq_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
List.of(

View File

@ -6,99 +6,244 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
class KeysTest {
private Keys keys;
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.KEYS);
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
@BeforeEach
void setup() {
keys = new Keys(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.KEYS.tableName());
keys = new Keys(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.PQ_LAST_RESORT_KEYS.tableName());
}
@Test
void testStore() {
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID),
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1, "pq-public-key", "sig")), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key")));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID),
keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, new SignedPreKey(1001, "pq-last-resort-key", "sig"));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")), null, null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key")), List.of(new SignedPreKey(2, "different-pq-public-key", "sig")), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(4, "fourth-public-key"), new PreKey(5, "fifth-public-key")),
List.of(new SignedPreKey(6, "sixth-pq-key", "sig"), new SignedPreKey(7, "seventh-pq-key", "sig")),
new SignedPreKey(1002, "new-last-resort-key", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
}
@Test
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.take(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = new PreKey(1, "public-key");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(Optional.of(preKey), keys.take(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = new SignedPreKey(1, "public-key", "sig");
final SignedPreKey preKey2 = new SignedPreKey(2, "different-public-key", "sig");
final SignedPreKey preKeyLast = new SignedPreKey(1001, "last-public-key", "sig");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testGetCount() {
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")), List.of(new SignedPreKey(1, "public-pq-key", "sig")), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testDeleteByAccount() {
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID);
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testDeleteByAccountAndDevice() {
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device")));
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, getLastResortCount(ACCOUNT_UUID));
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, new SignedPreKey(1L, "pub1", "sig1"), 2L, new SignedPreKey(2L, "pub2", "sig2")));
assertEquals(2, getLastResortCount(ACCOUNT_UUID));
assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, new SignedPreKey(3L, "pub3", "sig3"), 3L, new SignedPreKey(4L, "pub4", "sig4")));
assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
private int getLastResortCount(UUID uuid) {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(Tables.PQ_LAST_RESORT_KEYS.tableName())
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid)))
.select(Select.COUNT)
.build();
QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest);
return response.count();
}
@Test
void testGetPqEnabledDevices() {
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1L, "pub1", "sig1")), null);
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, new SignedPreKey(2L, "pub2", "sig2"));
keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(new SignedPreKey(3L, "pub3", "sig3")), new SignedPreKey(4L, "pub4", "sig4"));
keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID)));
}
@Test

View File

@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
@ -86,19 +87,25 @@ class KeysControllerTest {
private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final String PNI_IDENTITY_KEY = KeysHelper.serializeIdentityKey(PNI_IDENTITY_KEY_PAIR);
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6");
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6");
private final PreKey SAMPLE_KEY_PNI = new PreKey(7777, "test7");
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey( 1111, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey( 2222, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey( 3333, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey( 4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey( 5555, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey( 6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_PQ_KEY = new SignedPreKey(2424, "test1", "sig");
private final SignedPreKey SAMPLE_PQ_KEY2 = new SignedPreKey(6868, "test3", "sig");
private final SignedPreKey SAMPLE_PQ_KEY3 = new SignedPreKey(1313, "test5", "sig");
private final SignedPreKey SAMPLE_PQ_KEY_PNI = new SignedPreKey(8888, "test7", "sig");
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey(1111, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey(2222, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey(3333, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey(4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey(5555, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey(6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedPreKey(7777, PNI_IDENTITY_KEY_PAIR);
@ -177,10 +184,13 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI));
when(KEYS.getCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@ -210,8 +220,10 @@ class KeysControllerTest {
.get(PreKeyCount.class);
assertThat(result.getCount()).isEqualTo(5);
assertThat(result.getPqCount()).isEqualTo(5);
verify(KEYS).getCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1);
}
@ -223,9 +235,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature());
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result);
}
@Test
@ -237,9 +247,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class);
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getSignature());
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getPublicKey());
assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result);
}
@Test
@ -291,19 +299,63 @@ class KeysControllerTest {
@Test
void validSingleRequestTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.<SignedPreKey>empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@ -317,12 +369,33 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@ -338,12 +411,12 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey());
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@ -365,18 +438,20 @@ class KeysControllerTest {
@Test
void testUnidentifiedRequest() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@ -422,59 +497,118 @@ class KeysControllerTest {
@Test
void validMultiRequestTestV2() {
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2));
when(KEYS.take(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.take(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertKeysMatch(SAMPLE_KEY, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY.getPublicKey());
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertKeysMatch(SAMPLE_KEY2, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY2.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY2.getPublicKey());
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
deviceId = results.getDevice(4).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertKeysMatch(SAMPLE_KEY4, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(KEYS).take(EXISTS_UUID, 1);
verify(KEYS).take(EXISTS_UUID, 2);
verify(KEYS).take(EXISTS_UUID, 3);
verify(KEYS).take(EXISTS_UUID, 4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verifyNoMoreInteractions(KEYS);
}
@Test
void validMultiRequestPqTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.<SignedPreKey>empty());
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey);
assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
pqPreKey = results.getDevice(2).getPqPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey).isNull();
assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
pqPreKey = results.getDevice(4).getPqPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takePQ(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).takePQ(EXISTS_UUID, 4);
verifyNoMoreInteractions(KEYS);
}
@ -523,16 +657,12 @@ class KeysControllerTest {
@Test
void putKeysTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
List<PreKey> preKeys = new LinkedList<PreKey>() {{
add(preKey);
}};
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
Response response =
resources.getJerseyTest()
@ -544,12 +674,41 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture());
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void putKeysPqTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
@ -558,13 +717,12 @@ class KeysControllerTest {
@Test
void putKeysByPhoneNumberIdentifierTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
List<PreKey> preKeys = List.of(new PreKey(31337, "foobar"));
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
Response response =
resources.getJerseyTest()
@ -577,12 +735,42 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture());
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337);
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar");
assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void putKeysByPhoneNumberIdentifierPqTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.queryParam("identity", "pni")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
@ -627,7 +815,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture());
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@ -657,4 +845,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(403);
}
private void assertKeysMatch(PreKey expected, PreKey actual) {
assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId());
assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey());
if (expected instanceof final SignedPreKey signedExpected) {
final SignedPreKey signedActual = (SignedPreKey) actual;
assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature());
}
}
}