PQXDH endpoints for chat server

This commit is contained in:
Jonathan Klabunde Tomer 2023-05-16 17:34:33 -04:00 committed by GitHub
parent 34d77e73ff
commit caae27c44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1378 additions and 380 deletions

View File

@ -53,8 +53,12 @@ dynamoDbTables:
tableName: Example_IssuedReceipts tableName: Example_IssuedReceipts
expiration: P30D # Duration of time until rows expire expiration: P30D # Duration of time until rows expire
generator: abcdefg12345678= # random base64-encoded binary sequence generator: abcdefg12345678= # random base64-encoded binary sequence
keys: ecKeys:
tableName: Example_Keys tableName: Example_Keys
pqKeys:
tableName: Example_PQ_Keys
pqLastResortKeys:
tableName: Example_PQ_Last_Resort_Keys
messages: messages:
tableName: Example_Messages tableName: Example_Messages
expiration: P30D # Duration of time until rows expire expiration: P30D # Duration of time until rows expire

View File

@ -341,7 +341,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName()); config.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, config.getDynamoDbTables().getKeys().getTableName()); Keys keys = new Keys(dynamoDbClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getPqKeys().getTableName(),
config.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getMessages().getTableName(), config.getDynamoDbTables().getMessages().getTableName(),
config.getDynamoDbTables().getMessages().getExpiration(), config.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -50,7 +50,9 @@ public class DynamoDbTables {
private final Table deletedAccounts; private final Table deletedAccounts;
private final Table deletedAccountsLock; private final Table deletedAccountsLock;
private final IssuedReceiptsTableConfiguration issuedReceipts; private final IssuedReceiptsTableConfiguration issuedReceipts;
private final Table keys; private final Table ecKeys;
private final Table pqKeys;
private final Table pqLastResortKeys;
private final TableWithExpiration messages; private final TableWithExpiration messages;
private final Table pendingAccounts; private final Table pendingAccounts;
private final Table pendingDevices; private final Table pendingDevices;
@ -69,7 +71,9 @@ public class DynamoDbTables {
@JsonProperty("deletedAccounts") final Table deletedAccounts, @JsonProperty("deletedAccounts") final Table deletedAccounts,
@JsonProperty("deletedAccountsLock") final Table deletedAccountsLock, @JsonProperty("deletedAccountsLock") final Table deletedAccountsLock,
@JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts, @JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts,
@JsonProperty("keys") final Table keys, @JsonProperty("ecKeys") final Table ecKeys,
@JsonProperty("pqKeys") final Table pqKeys,
@JsonProperty("pqLastResortKeys") final Table pqLastResortKeys,
@JsonProperty("messages") final TableWithExpiration messages, @JsonProperty("messages") final TableWithExpiration messages,
@JsonProperty("pendingAccounts") final Table pendingAccounts, @JsonProperty("pendingAccounts") final Table pendingAccounts,
@JsonProperty("pendingDevices") final Table pendingDevices, @JsonProperty("pendingDevices") final Table pendingDevices,
@ -87,7 +91,9 @@ public class DynamoDbTables {
this.deletedAccounts = deletedAccounts; this.deletedAccounts = deletedAccounts;
this.deletedAccountsLock = deletedAccountsLock; this.deletedAccountsLock = deletedAccountsLock;
this.issuedReceipts = issuedReceipts; this.issuedReceipts = issuedReceipts;
this.keys = keys; this.ecKeys = ecKeys;
this.pqKeys = pqKeys;
this.pqLastResortKeys = pqLastResortKeys;
this.messages = messages; this.messages = messages;
this.pendingAccounts = pendingAccounts; this.pendingAccounts = pendingAccounts;
this.pendingDevices = pendingDevices; this.pendingDevices = pendingDevices;
@ -128,8 +134,20 @@ public class DynamoDbTables {
@NotNull @NotNull
@Valid @Valid
public Table getKeys() { public Table getEcKeys() {
return keys; return ecKeys;
}
@NotNull
@Valid
public Table getPqKeys() {
return pqKeys;
}
@NotNull
@Valid
public Table getPqLastResortKeys() {
return pqLastResortKeys;
} }
@NotNull @NotNull

View File

@ -493,6 +493,7 @@ public class AccountController {
request.number(), request.number(),
request.pniIdentityKey(), request.pniIdentityKey(),
request.devicePniSignedPrekeys(), request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(), request.deviceMessages(),
request.pniRegistrationIds()); request.pniRegistrationIds());

View File

@ -128,6 +128,7 @@ public class AccountControllerV2 {
request.number(), request.number(),
request.pniIdentityKey(), request.pniIdentityKey(),
request.devicePniSignedPrekeys(), request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(), request.deviceMessages(),
request.pniRegistrationIds()); request.pniRegistrationIds());
@ -172,10 +173,11 @@ public class AccountControllerV2 {
} }
try { try {
final Account updatedAccount = changeNumberManager.updatePNIKeys( final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedAccount.getAccount(), authenticatedAccount.getAccount(),
request.pniIdentityKey(), request.pniIdentityKey(),
request.devicePniSignedPrekeys(), request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
request.deviceMessages(), request.deviceMessages(),
request.pniRegistrationIds()); request.pniRegistrationIds());

View File

@ -11,14 +11,21 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.enums.ParameterIn;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -75,12 +82,14 @@ public class KeysController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Returns the number of available one-time prekeys for this device")
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) { @QueryParam("identity") final Optional<String> identityType) {
int count = keys.getCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); int ecCount = keys.getEcCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
int pqCount = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());
return new PreKeyCount(count); return new PreKeyCount(ecCount, pqCount);
} }
@Timed @Timed
@ -88,9 +97,17 @@ public class KeysController {
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState @ChangesDeviceEnabledState
@Operation(summary = "Sets the identity key for the account or phone-number identity and/or prekeys for this device")
public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth, public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@NotNull @Valid final PreKeyState preKeys, @RequestBody @NotNull @Valid final PreKeyState preKeys,
@Parameter(allowEmptyValue=true)
@Schema(
allowableValues={"aci", "pni"},
defaultValue="aci",
description="whether this operation applies to the account (aci) or phone-number (pni) identity")
@QueryParam("identity") final Optional<String> identityType, @QueryParam("identity") final Optional<String> identityType,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
Account account = disabledPermittedAuth.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
@ -98,7 +115,8 @@ public class KeysController {
final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType); final boolean usePhoneNumberIdentity = usePhoneNumberIdentity(identityType);
if (!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) { if (preKeys.getSignedPreKey() != null &&
!preKeys.getSignedPreKey().equals(usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey())) {
updateAccount = true; updateAccount = true;
} }
@ -121,13 +139,15 @@ public class KeysController {
if (updateAccount) { if (updateAccount) {
account = accounts.update(account, a -> { account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> { if (preKeys.getSignedPreKey() != null) {
if (usePhoneNumberIdentity) { a.getDevice(device.getId()).ifPresent(d -> {
d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey()); if (usePhoneNumberIdentity) {
} else { d.setPhoneNumberIdentitySignedPreKey(preKeys.getSignedPreKey());
d.setSignedPreKey(preKeys.getSignedPreKey()); } else {
} d.setSignedPreKey(preKeys.getSignedPreKey());
}); }
});
}
if (usePhoneNumberIdentity) { if (usePhoneNumberIdentity) {
a.setPhoneNumberIdentityKey(preKeys.getIdentityKey()); a.setPhoneNumberIdentityKey(preKeys.getIdentityKey());
@ -137,17 +157,29 @@ public class KeysController {
}); });
} }
keys.store(getIdentifier(account, identityType), device.getId(), preKeys.getPreKeys()); keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getPqLastResortPreKey());
} }
@Timed @Timed
@GET @GET
@Path("/{identifier}/{device_id}") @Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Retrieves the public identity key and available device prekeys for a specified account or phone-number identity")
public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth, public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@Parameter(description="the account or phone-number identifier to retrieve keys for")
@PathParam("identifier") UUID targetUuid, @PathParam("identifier") UUID targetUuid,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId, @PathParam("device_id") String deviceId,
@Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys")
@Schema(defaultValue="false")
@QueryParam("pq") boolean returnPqKey,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) @HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException { throws RateLimitExceededException {
@ -175,28 +207,30 @@ public class KeysController {
final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid); final boolean usePhoneNumberIdentity = target.getPhoneNumberIdentifier().equals(targetUuid);
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target, deviceId, usePhoneNumberIdentity); List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new LinkedList<>(); List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
for (Device device : target.getDevices()) { for (Device device : devices) {
if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) { UUID identifier = usePhoneNumberIdentity ? target.getPhoneNumberIdentifier() : targetUuid;
SignedPreKey signedPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey(); SignedPreKey signedECPreKey = usePhoneNumberIdentity ? device.getPhoneNumberIdentitySignedPreKey() : device.getSignedPreKey();
PreKey preKey = preKeysByDeviceId.get(device.getId()); PreKey unsignedECPreKey = keys.takeEC(identifier, device.getId()).orElse(null);
SignedPreKey pqPreKey = returnPqKey ? keys.takePQ(identifier, device.getId()).orElse(null) : null;
if (signedPreKey != null || preKey != null) { if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = usePhoneNumberIdentity ? final int registrationId = usePhoneNumberIdentity ?
device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) : device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()) :
device.getRegistrationId(); device.getRegistrationId();
responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedPreKey, preKey)); responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
}
} }
} }
final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey(); final String identityKey = usePhoneNumberIdentity ? target.getPhoneNumberIdentityKey() : target.getIdentityKey();
if (responseItems.isEmpty()) return Response.status(404).build(); if (responseItems.isEmpty()) {
else return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build(); return Response.status(404).build();
}
return Response.ok().entity(new PreKeyResponse(identityKey, responseItems)).build();
} }
@Timed @Timed
@ -243,31 +277,15 @@ public class KeysController {
account.getUuid(); account.getUuid();
} }
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector, final boolean usePhoneNumberIdentity) { private List<Device> parseDeviceId(String deviceId, Account account) {
final Map<Long, PreKey> preKeys; if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();
final UUID identifier = usePhoneNumberIdentity ? }
destination.getPhoneNumberIdentifier() : try {
destination.getUuid(); long id = Long.parseLong(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
if (deviceIdSelector.equals("*")) { } catch (NumberFormatException e) {
preKeys = new HashMap<>(); throw new WebApplicationException(Response.status(422).build());
for (final Device device : destination.getDevices()) {
keys.take(identifier, device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey));
}
} else {
try {
long deviceId = Long.parseLong(deviceIdSelector);
preKeys = keys.take(identifier, deviceId)
.map(preKey -> Map.of(deviceId, preKey))
.orElse(Collections.emptyMap());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
} }
return preKeys;
} }
} }

View File

@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -16,21 +18,57 @@ import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record ChangeNumberRequest(String sessionId, public record ChangeNumberRequest(
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, @Schema(description="""
@NotBlank String number, A session ID from registration service, if using session id to authenticate this request.
@JsonProperty("reglock") @Nullable String registrationLock, Must not be combined with `recoveryPassword`.""")
@NotBlank String pniIdentityKey, String sessionId,
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys, @Schema(description="""
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest { The recovery password for the new phone number, if using a recovery password to authenticate this request.
Must not be combined with `sessionId`.""")
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@Schema(description="the new phone number for this account")
@NotBlank String number,
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@NotBlank String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) { List<SignedPreKey> spks = new ArrayList<>();
return true; if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
} }
return devicePniSignedPrekeys.values().parallelStream() if (devicePniPqLastResortPrekeys != null) {
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
} }
} }

View File

@ -6,27 +6,61 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
public record ChangePhoneNumberRequest(@NotBlank String number, public record ChangePhoneNumberRequest(
@NotBlank String code, @Schema(description="the new phone number for this account")
@JsonProperty("reglock") @Nullable String registrationLock, @NotBlank String number,
@Nullable String pniIdentityKey,
@Nullable List<IncomingMessage> deviceMessages, @Schema(description="the registration verification code to authenticate this request")
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys, @NotBlank String code,
@Nullable Map<Long, Integer> pniRegistrationIds) {
@Schema(description="the registration lock password for the new phone number, if necessary")
@JsonProperty("reglock") @Nullable String registrationLock,
@Schema(description="the new public identity key to use for the phone-number identity associated with the new phone number")
@Nullable String pniIdentityKey,
@Schema(description="""
A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
@Nullable List<IncomingMessage> deviceMessages,
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable Map<Long, SignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Nullable @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@Nullable Map<Long, Integer> pniRegistrationIds) {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
if (devicePniSignedPrekeys == null) { List<SignedPreKey> spks = new ArrayList<>();
return true; if (devicePniSignedPrekeys != null) {
spks.addAll(devicePniSignedPrekeys.values());
} }
return devicePniSignedPrekeys.values().parallelStream() if (devicePniPqLastResortPrekeys != null) {
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
} }
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -17,29 +18,45 @@ import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record PhoneNumberIdentityKeyDistributionRequest( public record PhoneNumberIdentityKeyDistributionRequest(
@NotBlank @NotBlank
@Schema(description="the new identity key for this account's phone-number identity") @Schema(description="the new identity key for this account's phone-number identity")
String pniIdentityKey, String pniIdentityKey,
@NotNull @NotNull
@Valid @Valid
@Schema(description="A message for each companion device to pass its new private keys") @Schema(description="""
List<@NotNull @Valid IncomingMessage> deviceMessages, A list of synchronization messages to send to companion devices to supply the private keys
associated with the new identity key and their new prekeys.
Exactly one message must be supplied for each enabled device other than the sending (primary) device.""")
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @NotNull
@Valid @Valid
@Schema(description="The public key of a new signed elliptic-curve prekey pair for each device") @Schema(description="""
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys, A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull @Schema(description="""
@Valid A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
@Schema(description="The new registration ID to use for the phone-number identity of each device") May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
Map<Long, Integer> pniRegistrationIds) { If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniPqLastResortPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device")
Map<Long, Integer> pniRegistrationIds) {
@AssertTrue @AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() { public boolean isSignatureValidOnEachSignedPreKey() {
return devicePniSignedPrekeys.values().parallelStream() List<SignedPreKey> spks = new ArrayList<>(devicePniSignedPrekeys.values());
.allMatch(spk -> PreKeySignatureValidator.validatePreKeySignature(pniIdentityKey, spk)); if (devicePniPqLastResortPrekeys != null) {
spks.addAll(devicePniPqLastResortPrekeys.values());
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(pniIdentityKey, spks);
} }
} }

View File

@ -13,17 +13,17 @@ public class PreKey {
@JsonProperty @JsonProperty
@NotNull @NotNull
private long keyId; private long keyId;
@JsonProperty @JsonProperty
@NotEmpty @NotEmpty
private String publicKey; private String publicKey;
public PreKey() {} public PreKey() {}
public PreKey(long keyId, String publicKey) public PreKey(long keyId, String publicKey)
{ {
this.keyId = keyId; this.keyId = keyId;
this.publicKey = publicKey; this.publicKey = publicKey;
} }
@ -63,5 +63,4 @@ public class PreKey {
return ((int)this.keyId) ^ publicKey.hashCode(); return ((int)this.keyId) ^ publicKey.hashCode();
} }
} }
} }

View File

@ -5,16 +5,22 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyCount { public class PreKeyCount {
@Schema(description="the number of stored unsigned elliptic-curve prekeys for this device")
@JsonProperty @JsonProperty
private int count; private int count;
public PreKeyCount(int count) { @Schema(description="the number of stored one-time post-quantum prekeys for this device")
this.count = count; @JsonProperty
private int pqCount;
public PreKeyCount(int ecCount, int pqCount) {
this.count = ecCount;
this.pqCount = pqCount;
} }
public PreKeyCount() {} public PreKeyCount() {}
@ -22,4 +28,8 @@ public class PreKeyCount {
public int getCount() { public int getCount() {
return count; return count;
} }
public int getPqCount() {
return pqCount;
}
} }

View File

@ -7,15 +7,18 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List; import java.util.List;
public class PreKeyResponse { public class PreKeyResponse {
@JsonProperty @JsonProperty
@Schema(description="the public identity key for the requested identity")
private String identityKey; private String identityKey;
@JsonProperty @JsonProperty
@Schema(description="information about each requested device")
private List<PreKeyResponseItem> devices; private List<PreKeyResponseItem> devices;
public PreKeyResponse() {} public PreKeyResponse() {}

View File

@ -6,28 +6,39 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
public class PreKeyResponseItem { public class PreKeyResponseItem {
@JsonProperty @JsonProperty
@Schema(description="the device ID of the device to which this item pertains")
private long deviceId; private long deviceId;
@JsonProperty @JsonProperty
@Schema(description="the registration ID for the device")
private int registrationId; private int registrationId;
@JsonProperty @JsonProperty
@Schema(description="the signed elliptic-curve prekey for the device, if one has been set")
private SignedPreKey signedPreKey; private SignedPreKey signedPreKey;
@JsonProperty @JsonProperty
@Schema(description="an unsigned elliptic-curve prekey for the device, if any remain")
private PreKey preKey; private PreKey preKey;
@JsonProperty
@Schema(description="a signed post-quantum prekey for the device " +
"(a one-time prekey if any remain, otherwise the last-resort prekey if one has been set)")
private SignedPreKey pqPreKey;
public PreKeyResponseItem() {} public PreKeyResponseItem() {}
public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey) { public PreKeyResponseItem(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKey preKey, SignedPreKey pqPreKey) {
this.deviceId = deviceId; this.deviceId = deviceId;
this.registrationId = registrationId; this.registrationId = registrationId;
this.signedPreKey = signedPreKey; this.signedPreKey = signedPreKey;
this.preKey = preKey; this.preKey = preKey;
this.pqPreKey = pqPreKey;
} }
@VisibleForTesting @VisibleForTesting
@ -40,6 +51,11 @@ public class PreKeyResponseItem {
return preKey; return preKey;
} }
@VisibleForTesting
public SignedPreKey getPqPreKey() {
return pqPreKey;
}
@VisibleForTesting @VisibleForTesting
public int getRegistrationId() { public int getRegistrationId() {
return registrationId; return registrationId;

View File

@ -5,24 +5,38 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Base64; import java.util.Base64;
import java.util.Collection;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
public abstract class PreKeySignatureValidator { public abstract class PreKeySignatureValidator {
public static final boolean validatePreKeySignature(final String identityKeyB64, final SignedPreKey spk) { public static final Counter INVALID_SIGNATURE_COUNTER =
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature"));
public static final boolean validatePreKeySignatures(final String identityKeyB64, final Collection<SignedPreKey> spks) {
try { try {
final byte[] identityKeyBytes = Base64.getDecoder().decode(identityKeyB64); final byte[] identityKeyBytes = Base64.getDecoder().decode(identityKeyB64);
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0); final ECPublicKey identityKey = Curve.decodePoint(identityKeyBytes, 0);
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes); final boolean success = spks.stream().allMatch(spk -> {
final byte[] prekeyBytes = Base64.getDecoder().decode(spk.getPublicKey());
final byte[] prekeySignatureBytes = Base64.getDecoder().decode(spk.getSignature());
return identityKey.verifySignature(prekeyBytes, prekeySignatureBytes);
});
if (!success) {
INVALID_SIGNATURE_COUNTER.increment();
}
return success;
} catch (IllegalArgumentException | InvalidKeyException e) { } catch (IllegalArgumentException | InvalidKeyException e) {
Metrics.counter(name(PreKeySignatureValidator.class, "invalidPreKeySignature")).increment(); INVALID_SIGNATURE_COUNTER.increment();
return false; return false;
} }
} }

View File

@ -6,6 +6,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue; import javax.validation.constraints.AssertTrue;
@ -15,26 +17,59 @@ import javax.validation.constraints.NotNull;
public class PreKeyState { public class PreKeyState {
@JsonProperty @JsonProperty
@NotNull
@Valid @Valid
@Schema(description="A list of unsigned elliptic-curve prekeys to use for this device. " +
"If present and not empty, replaces all stored unsigned EC prekeys for the device; " +
"if absent or empty, any stored unsigned EC prekeys for the device are not deleted.")
private List<PreKey> preKeys; private List<PreKey> preKeys;
@JsonProperty @JsonProperty
@NotNull
@Valid @Valid
@Schema(description="An optional signed elliptic-curve prekey to use for this device. " +
"If present, replaces the stored signed elliptic-curve prekey for the device; " +
"if absent, the stored signed prekey is not deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey signedPreKey; private SignedPreKey signedPreKey;
@JsonProperty
@Valid
@Schema(description="A list of signed post-quantum one-time prekeys to use for this device. " +
"Each key must have a valid signature from the identity key in this request. " +
"If present and not empty, replaces all stored unsigned PQ prekeys for the device; " +
"if absent or empty, any stored unsigned PQ prekeys for the device are not deleted.")
private List<SignedPreKey> pqPreKeys;
@JsonProperty
@Valid
@Schema(description="An optional signed last-resort post-quantum prekey to use for this device. " +
"If present, replaces the stored signed post-quantum last-resort prekey for the device; " +
"if absent, a stored last-resort prekey will *not* be deleted. " +
"If present, must have a valid signature from the identity key in this request.")
private SignedPreKey pqLastResortPreKey;
@JsonProperty @JsonProperty
@NotEmpty @NotEmpty
@NotNull
@Schema(description="Required. " +
"The public identity key for this identity (account or phone-number identity). " +
"If this device is not the primary device for the account, " +
"must match the existing stored identity key for this identity.")
private String identityKey; private String identityKey;
public PreKeyState() {} public PreKeyState() {}
@VisibleForTesting @VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys) { public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys) {
this.identityKey = identityKey; this(identityKey, signedPreKey, keys, null, null);
this.signedPreKey = signedPreKey; }
this.preKeys = keys;
@VisibleForTesting
public PreKeyState(String identityKey, SignedPreKey signedPreKey, List<PreKey> keys, List<SignedPreKey> pqKeys, SignedPreKey pqLastResortKey) {
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this.pqPreKeys = pqKeys;
this.pqLastResortPreKey = pqLastResortKey;
} }
public List<PreKey> getPreKeys() { public List<PreKey> getPreKeys() {
@ -45,12 +80,30 @@ public class PreKeyState {
return signedPreKey; return signedPreKey;
} }
public List<SignedPreKey> getPqPreKeys() {
return pqPreKeys;
}
public SignedPreKey getPqLastResortPreKey() {
return pqLastResortPreKey;
}
public String getIdentityKey() { public String getIdentityKey() {
return identityKey; return identityKey;
} }
@AssertTrue @AssertTrue
public boolean isSignatureValid() { public boolean isSignatureValidOnEachSignedKey() {
return PreKeySignatureValidator.validatePreKeySignature(identityKey, signedPreKey); List<SignedPreKey> spks = new ArrayList<>();
if (pqPreKeys != null) {
spks.addAll(pqPreKeys);
}
if (pqLastResortPreKey != null) {
spks.add(pqLastResortPreKey);
}
if (signedPreKey != null) {
spks.add(signedPreKey);
}
return spks.isEmpty() || PreKeySignatureValidator.validatePreKeySignatures(identityKey, spks);
} }
} }

View File

@ -45,5 +45,4 @@ public class SignedPreKey extends PreKey {
return super.hashCode() ^ signature.hashCode(); return super.hashCode() ^ signature.hashCode();
} }
} }
} }

View File

@ -12,6 +12,7 @@ import static io.micrometer.core.instrument.Metrics.timer;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -53,7 +54,7 @@ public abstract class AbstractDynamoDbStore {
return dynamoDbClient; return dynamoDbClient;
} }
protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) { protected void executeTableWriteItemsUntilComplete(final Map<String, ? extends Collection<WriteRequest>> items) {
final AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>(); final AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome); writeAndStoreOutcome(items, batchWriteItemsFirstPass, outcome);
int attemptCount = 0; int attemptCount = 0;
@ -80,7 +81,7 @@ public abstract class AbstractDynamoDbStore {
} }
private void writeAndStoreOutcome( private void writeAndStoreOutcome(
final Map<String, List<WriteRequest>> items, final Map<String, ? extends Collection<WriteRequest>> items,
final Timer timer, final Timer timer,
final AtomicReference<BatchWriteItemResponse> outcome) { final AtomicReference<BatchWriteItemResponse> outcome) {
timer.record( timer.record(

View File

@ -245,6 +245,7 @@ public class AccountsManager {
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey, @Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { @Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber(); final String originalNumber = account.getNumber();
@ -252,12 +253,12 @@ public class AccountsManager {
if (originalNumber.equals(number)) { if (originalNumber.equals(number)) {
if (pniIdentityKey != null) { if (pniIdentityKey != null) {
throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePNIKeys"); throw new IllegalArgumentException("change number must supply a changed phone number; otherwise use updatePniKeys");
} }
return account; return account;
} }
validateDevices(account, pniSignedPreKeys, pniRegistrationIds); validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final AtomicReference<Account> updatedAccount = new AtomicReference<>(); final AtomicReference<Account> updatedAccount = new AtomicReference<>();
@ -281,7 +282,7 @@ public class AccountsManager {
numberChangedAccount = updateWithRetries( numberChangedAccount = updateWithRetries(
account, account,
a -> setPNIKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), a -> { setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); return true; },
a -> accounts.changeNumber(a, number, phoneNumberIdentifier), a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(), () -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
@ -291,45 +292,74 @@ public class AccountsManager {
keys.delete(phoneNumberIdentifier); keys.delete(phoneNumberIdentifier);
keys.delete(originalPhoneNumberIdentifier); keys.delete(originalPhoneNumberIdentifier);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(
phoneNumberIdentifier,
keys.getPqEnabledDevices(uuid).stream().collect(
Collectors.toMap(
Function.identity(),
pniPqLastResortPreKeys::get)));
}
return displacedUuid; return displacedUuid;
}); });
return updatedAccount.get(); return updatedAccount.get();
} }
public Account updatePNIKeys(final Account account, public Account updatePniKeys(final Account account,
final String pniIdentityKey, final String pniIdentityKey,
final Map<Long, SignedPreKey> pniSignedPreKeys, final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException { final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniRegistrationIds); validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
return update(account, a -> { return setPNIKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); final UUID pni = account.getPhoneNumberIdentifier();
final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); });
final List<Long> pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni);
keys.delete(pni);
if (pniPqLastResortPreKeys != null) {
keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)));
}
return updatedAccount;
} }
private boolean setPNIKeys(final Account account, private boolean setPniKeys(final Account account,
@Nullable final String pniIdentityKey, @Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) { @Nullable final Map<Long, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return true; return false;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null");
} }
pniSignedPreKeys.forEach((deviceId, signedPreKey) -> boolean changed = !pniIdentityKey.equals(account.getPhoneNumberIdentityKey());
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
for (Device device : account.getDevices()) {
if (!device.isEnabled()) {
continue;
}
SignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId());
changed = changed ||
!signedPreKey.equals(device.getPhoneNumberIdentitySignedPreKey()) ||
device.getRegistrationId() != registrationId;
device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
device.setPhoneNumberIdentityRegistrationId(registrationId);
}
pniRegistrationIds.forEach((deviceId, registrationId) -> account.setPhoneNumberIdentityKey(pniIdentityKey);
account.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
account.setPhoneNumberIdentityKey(pniIdentityKey); return changed;
return true;
} }
private void validateDevices(final Account account, private void validateDevices(final Account account,
final Map<Long, SignedPreKey> pniSignedPreKeys, @Nullable final Map<Long, SignedPreKey> pniSignedPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException { @Nullable final Map<Long, SignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) { if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return; return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) { } else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
@ -342,6 +372,12 @@ public class AccountsManager {
pniSignedPreKeys.keySet(), pniSignedPreKeys.keySet(),
Collections.emptySet()); Collections.emptySet());
// Check that all including master ID are in Pq pre-keys
DestinationDeviceValidator.validateCompleteDeviceList(
account,
pniSignedPreKeys.keySet(),
Collections.emptySet());
// Check that all devices are accounted for in the map of new PNI registration IDs // Check that all devices are accounted for in the map of new PNI registration IDs
DestinationDeviceValidator.validateCompleteDeviceList( DestinationDeviceValidator.validateCompleteDeviceList(
account, account,

View File

@ -42,6 +42,7 @@ public class ChangeNumberManager {
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@Nullable final String pniIdentityKey, @Nullable final String pniIdentityKey,
@Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys, @Nullable final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages, @Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds) @Nullable final Map<Long, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException { throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
@ -62,10 +63,14 @@ public class ChangeNumberManager {
// We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key // We don't need to actually do a number-change operation in our DB, but we *do* need to accept their new key
// material and distribute the sync messages, to be sure all clients agree with us and each other about what their // material and distribute the sync messages, to be sure all clients agree with us and each other about what their
// keys are. Pretend this change-number request was actually a PNI key distribution request. // keys are. Pretend this change-number request was actually a PNI key distribution request.
return updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds); if (pniIdentityKey == null) {
return account;
}
return updatePniKeys(account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, deviceMessages, pniRegistrationIds);
} }
final Account updatedAccount = accountsManager.changeNumber(account, number, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); final Account updatedAccount = accountsManager.changeNumber(
account, number, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
if (deviceMessages != null) { if (deviceMessages != null) {
sendDeviceMessages(updatedAccount, deviceMessages); sendDeviceMessages(updatedAccount, deviceMessages);
@ -74,16 +79,18 @@ public class ChangeNumberManager {
return updatedAccount; return updatedAccount;
} }
public Account updatePNIKeys(final Account account, public Account updatePniKeys(final Account account,
final String pniIdentityKey, final String pniIdentityKey,
final Map<Long, SignedPreKey> deviceSignedPreKeys, final Map<Long, SignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, SignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages, final List<IncomingMessage> deviceMessages,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
validateDeviceMessages(account, deviceMessages); validateDeviceMessages(account, deviceMessages);
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb
// write anyway. Linked devices can handle some wasted extra key rotations. // write anyway. Linked devices can handle some wasted extra key rotations.
final Account updatedAccount = accountsManager.updatePNIKeys(account, pniIdentityKey, deviceSignedPreKeys, pniRegistrationIds); final Account updatedAccount = accountsManager.updatePniKeys(
account, pniIdentityKey, deviceSignedPreKeys, devicePqLastResortPreKeys, pniRegistrationIds);
sendDeviceMessages(updatedAccount, deviceMessages); sendDeviceMessages(updatedAccount, deviceMessages);
return updatedAccount; return updatedAccount;

View File

@ -6,6 +6,9 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Multimaps;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
@ -16,7 +19,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -34,11 +41,14 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class Keys extends AbstractDynamoDbStore { public class Keys extends AbstractDynamoDbStore {
private final String tableName; private final String ecTableName;
private final String pqTableName;
private final String pqLastResortTableName;
static final String KEY_ACCOUNT_UUID = "U"; static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK"; static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String KEY_PUBLIC_KEY = "P"; static final String KEY_PUBLIC_KEY = "P";
static final String KEY_SIGNATURE = "S";
private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys")); private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys"));
private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice")); private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice"));
@ -48,31 +58,114 @@ public class Keys extends AbstractDynamoDbStore {
private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys")); private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys"));
private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount")); private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount"));
private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty")); private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty"));
private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys"));
public Keys(final DynamoDbClient dynamoDB, final String tableName) { public Keys(
final DynamoDbClient dynamoDB,
final String ecTableName,
final String pqTableName,
final String pqLastResortTableName) {
super(dynamoDB); super(dynamoDB);
this.tableName = tableName; this.ecTableName = ecTableName;
this.pqTableName = pqTableName;
this.pqLastResortTableName = pqLastResortTableName;
} }
public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) { public void store(final UUID identifier, final long deviceId, final List<PreKey> keys) {
STORE_KEYS_TIMER.record(() -> { store(identifier, deviceId, keys, null, null);
delete(identifier, deviceId); }
writeInBatches(keys, batch -> { public void store(
List<WriteRequest> items = new ArrayList<>(); final UUID identifier, final long deviceId,
for (final PreKey preKey : batch) { @Nullable final List<PreKey> ecKeys,
items.add(WriteRequest.builder() @Nullable final List<SignedPreKey> pqKeys,
.putRequest(PutRequest.builder() @Nullable final SignedPreKey pqLastResortKey) {
.item(getItemFromPreKey(identifier, deviceId, preKey)) Multimap<String, PreKey> keys = MultimapBuilder.hashKeys().arrayListValues().build();
.build()) List<String> tablesToClear = new ArrayList<>();
.build());
} if (ecKeys != null && !ecKeys.isEmpty()) {
executeTableWriteItemsUntilComplete(Map.of(tableName, items)); keys.putAll(ecTableName, ecKeys);
tablesToClear.add(ecTableName);
}
if (pqKeys != null && !pqKeys.isEmpty()) {
keys.putAll(pqTableName, pqKeys);
tablesToClear.add(pqTableName);
}
if (pqLastResortKey != null) {
keys.put(pqLastResortTableName, pqLastResortKey);
tablesToClear.add(pqLastResortTableName);
}
STORE_KEYS_TIMER.record(() -> {
delete(tablesToClear, identifier, deviceId);
writeInBatches(
keys.entries(),
batch -> {
Multimap<String, WriteRequest> writes = batch.stream()
.collect(
Multimaps.toMultimap(
Map.Entry<String, PreKey>::getKey,
entry -> WriteRequest.builder()
.putRequest(PutRequest.builder()
.item(getItemFromPreKey(identifier, deviceId, entry.getValue()))
.build())
.build(),
MultimapBuilder.hashKeys().arrayListValues()::build));
executeTableWriteItemsUntilComplete(writes.asMap());
}); });
}); });
} }
public Optional<PreKey> take(final UUID identifier, final long deviceId) { public void storePqLastResort(final UUID identifier, final Map<Long, SignedPreKey> keys) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true)
.build();
final List<WriteRequest> writes = new ArrayList<>(2 * keys.size());
final Map<Long, Map<String, AttributeValue>> newItems = keys.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue())));
for (final Map<String, AttributeValue> item : db().query(queryRequest).items()) {
final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID);
final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong();
if (newItems.containsKey(oldDeviceId)) {
final Map<String, AttributeValue> replacement = newItems.get(oldDeviceId);
if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) {
writes.add(WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, oldSortKey))
.build())
.build());
}
}
}
newItems.forEach((unusedKey, item) ->
writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build()));
executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes));
}
public Optional<PreKey> takeEC(final UUID identifier, final long deviceId) {
return take(ecTableName, identifier, deviceId);
}
public Optional<SignedPreKey> takePQ(final UUID identifier, final long deviceId) {
return take(pqTableName, identifier, deviceId)
.or(() -> getLastResort(identifier, deviceId))
.map(pk -> (SignedPreKey) pk);
}
private Optional<PreKey> take(final String tableName, final UUID identifier, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final AttributeValue partitionKey = getPartitionKey(identifier); final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder() QueryRequest queryRequest = QueryRequest.builder()
@ -114,7 +207,53 @@ public class Keys extends AbstractDynamoDbStore {
}); });
} }
public int getCount(final UUID identifier, final long deviceId) { @VisibleForTesting
Optional<PreKey> getLastResort(final UUID identifier, final long deviceId) {
final AttributeValue partitionKey = getPartitionKey(identifier);
QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of(
":uuid", partitionKey,
":sortprefix", getSortKeyPrefix(deviceId)))
.consistentRead(false)
.select(Select.ALL_ATTRIBUTES)
.build();
QueryResponse response = db().query(queryRequest);
if (response.count() > 1) {
TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment();
}
return response.items().stream().findFirst().map(this::getPreKeyFromItem);
}
public List<Long> getPqEnabledDevices(final UUID identifier) {
final AttributeValue partitionKey = getPartitionKey(identifier);
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(pqLastResortTableName)
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", partitionKey))
.projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(false)
.build();
final QueryResponse response = db().query(queryRequest);
return response.items().stream()
.map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong())
.toList();
}
public int getEcCount(final UUID identifier, final long deviceId) {
return getCount(ecTableName, identifier, deviceId);
}
public int getPqCount(final UUID identifier, final long deviceId) {
return getCount(pqTableName, identifier, deviceId);
}
private int getCount(final String tableName, final UUID identifier, final long deviceId) {
return GET_KEY_COUNT_TIMER.record(() -> { return GET_KEY_COUNT_TIMER.record(() -> {
QueryRequest queryRequest = QueryRequest.builder() QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName) .tableName(tableName)
@ -144,51 +283,66 @@ public class Keys extends AbstractDynamoDbStore {
public void delete(final UUID accountUuid) { public void delete(final UUID accountUuid) {
DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder() final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid") .keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid))) ":uuid", getPartitionKey(accountUuid)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID) .projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true) .consistentRead(true)
.build(); .build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest); deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest);
}); });
} }
public void delete(final UUID accountUuid, final long deviceId) { public void delete(final UUID accountUuid, final long deviceId) {
delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId);
}
private void delete(final List<String> tableNames, final UUID accountUuid, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder() final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":uuid", getPartitionKey(accountUuid), ":uuid", getPartitionKey(accountUuid),
":sortprefix", getSortKeyPrefix(deviceId))) ":sortprefix", getSortKeyPrefix(deviceId)))
.projectionExpression(KEY_DEVICE_ID_KEY_ID) .projectionExpression(KEY_DEVICE_ID_KEY_ID)
.consistentRead(true) .consistentRead(true)
.build(); .build();
deleteItemsForAccountMatchingQuery(accountUuid, queryRequest); deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest);
}); });
} }
private void deleteItemsForAccountMatchingQuery(final UUID accountUuid, final QueryRequest querySpec) { private void deleteItemsForAccountMatchingQuery(final List<String> tableNames, final UUID accountUuid, final QueryRequest querySpec) {
final AttributeValue partitionKey = getPartitionKey(accountUuid); final AttributeValue partitionKey = getPartitionKey(accountUuid);
writeInBatches(db().query(querySpec).items(), batch -> { Multimap<String, Map<String, AttributeValue>> itemStream = tableNames.stream()
List<WriteRequest> deletes = new ArrayList<>(); .collect(
for (final Map<String, AttributeValue> item : batch) { Multimaps.flatteningToMultimap(
deletes.add(WriteRequest.builder() Function.identity(),
.deleteRequest(DeleteRequest.builder() tableName ->
.key(Map.of( db().query(querySpec.toBuilder().tableName(tableName).build())
KEY_ACCOUNT_UUID, partitionKey, .items()
KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID))) .stream(),
.build()) MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
.build());
} writeInBatches(
executeTableWriteItemsUntilComplete(Map.of(tableName, deletes)); itemStream.entries(),
batch -> {
Multimap<String, WriteRequest> deletes = batch.stream()
.collect(Multimaps.toMultimap(
Map.Entry<String, Map<String, AttributeValue>>::getKey,
entry -> WriteRequest.builder()
.deleteRequest(DeleteRequest.builder()
.key(Map.of(
KEY_ACCOUNT_UUID, partitionKey,
KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID)))
.build())
.build(),
MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build));
executeTableWriteItemsUntilComplete(deletes.asMap());
}); });
} }
@ -211,6 +365,13 @@ public class Keys extends AbstractDynamoDbStore {
} }
private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { private Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) {
if (preKey instanceof final SignedPreKey spk) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()),
KEY_PUBLIC_KEY, AttributeValues.fromString(spk.getPublicKey()),
KEY_SIGNATURE, AttributeValues.fromString(spk.getSignature()));
}
return Map.of( return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()),
@ -219,6 +380,11 @@ public class Keys extends AbstractDynamoDbStore {
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) { private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s());
}
return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s()); return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s());
} }
} }

View File

@ -174,7 +174,9 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName()); configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName()); configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(), configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -154,7 +154,9 @@ record CommandDependencies(
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getProfiles().getTableName()); configuration.getDynamoDbTables().getProfiles().getTableName());
Keys keys = new Keys(dynamoDbClient, Keys keys = new Keys(dynamoDbClient,
configuration.getDynamoDbTables().getKeys().getTableName()); configuration.getDynamoDbTables().getEcKeys().getTableName(),
configuration.getDynamoDbTables().getPqKeys().getTableName(),
configuration.getDynamoDbTables().getPqLastResortKeys().getTableName());
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getTableName(),
configuration.getDynamoDbTables().getMessages().getExpiration(), configuration.getDynamoDbTables().getMessages().getExpiration(),

View File

@ -334,7 +334,7 @@ class AccountControllerTest {
return account; return account;
}); });
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> { when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class); final String number = invocation.getArgument(1, String.class);
final String pniIdentityKey = invocation.getArgument(2, String.class); final String pniIdentityKey = invocation.getArgument(2, String.class);
@ -358,7 +358,7 @@ class AccountControllerTest {
return updatedAccount; return updatedAccount;
}); });
when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> { when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final String pniIdentityKey = invocation.getArgument(1, String.class); final String pniIdentityKey = invocation.getArgument(1, String.class);
@ -1377,12 +1377,12 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT); verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any());
assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.number()).isEqualTo(number); assertThat(accountIdentityResponse.number()).isEqualTo(number);
@ -1399,12 +1399,12 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.readEntity(String.class)).isBlank(); assertThat(response.readEntity(String.class)).isBlank();
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1417,7 +1417,7 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getStatus()).isEqualTo(400);
@ -1426,7 +1426,7 @@ class AccountControllerTest {
assertThat(responseEntity.getOriginalNumber()).isEqualTo(number); assertThat(responseEntity.getOriginalNumber()).isEqualTo(number);
assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111"); assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111");
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1436,10 +1436,10 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1454,11 +1454,11 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1478,13 +1478,13 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT); verify(registrationServiceClient).checkVerificationCode(sessionId, code, AccountController.REGISTRATION_RPC_TIMEOUT);
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1514,11 +1514,11 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1549,14 +1549,14 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423); assertThat(response.getStatus()).isEqualTo(423);
// verify(existingAccount).lockAuthenticationCredentials(); // verify(existingAccount).lockAuthenticationCredentials();
// verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any()); // verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1589,14 +1589,14 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423); assertThat(response.getStatus()).isEqualTo(423);
// verify(existingAccount).lockAuthenticationCredentials(); // verify(existingAccount).lockAuthenticationCredentials();
// verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any()); // verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1628,13 +1628,13 @@ class AccountControllerTest {
.target("/v1/accounts/number") .target("/v1/accounts/number")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null), .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
verify(senderRegLockAccount, never()).lockAuthTokenHash(); verify(senderRegLockAccount, never()).lockAuthTokenHash();
verify(clientPresenceManager, never()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any()); verify(clientPresenceManager, never()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any(), any());
} }
@Test @Test
@ -1681,10 +1681,11 @@ class AccountControllerTest {
number, code, null, number, code, null,
pniIdentityKey, deviceMessages, pniIdentityKey, deviceMessages,
deviceKeys, deviceKeys,
null,
registrationIds), registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(number), any(), any(), any(), any(), any());
assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID); assertThat(accountIdentityResponse.uuid()).isEqualTo(AuthHelper.VALID_UUID);
assertThat(accountIdentityResponse.number()).isEqualTo(number); assertThat(accountIdentityResponse.number()).isEqualTo(number);
@ -1734,11 +1735,12 @@ class AccountControllerTest {
AuthHelper.VALID_NUMBER, code, null, AuthHelper.VALID_NUMBER, code, null,
pniIdentityKey, deviceMessages, pniIdentityKey, deviceMessages,
deviceKeys, deviceKeys,
null,
registrationIds), registrationIds),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber( verify(changeNumberManager).changeNumber(
eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any()); eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), any(), any());
verifyNoInteractions(rateLimiter); verifyNoInteractions(rateLimiter);
verifyNoInteractions(pendingAccountsManager); verifyNoInteractions(pendingAccountsManager);

View File

@ -134,7 +134,7 @@ class AccountControllerV2Test {
void setUp() throws Exception { void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer( when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> { (Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class); final String number = invocation.getArgument(1, String.class);
@ -180,11 +180,11 @@ class AccountControllerV2Test {
.put(Entity.entity( .put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123", new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123",
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), Collections.emptyMap()), Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any()); any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number()); assertEquals(NEW_NUMBER, accountIdentityResponse.number());
@ -203,11 +203,11 @@ class AccountControllerV2Test {
new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null, new ChangeNumberRequest(encodeSessionId("session"), null, AuthHelper.VALID_NUMBER, null,
"pni-identity-key", "pni-identity-key",
Collections.emptyList(), Collections.emptyList(),
Collections.emptyMap(), Collections.emptyMap()), Collections.emptyMap(), null, Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(), verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_NUMBER), any(), any(), any(),
any()); any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number());
@ -365,7 +365,7 @@ class AccountControllerV2Test {
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class); final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any()); any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number()); assertEquals(NEW_NUMBER, accountIdentityResponse.number());
@ -458,7 +458,7 @@ class AccountControllerV2Test {
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
when(changeNumberManager.updatePNIKeys(any(), any(), any(), any(), any())).thenAnswer( when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> { (Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final String pniIdentityKey = invocation.getArgument(1, String.class); final String pniIdentityKey = invocation.getArgument(1, String.class);
@ -496,7 +496,7 @@ class AccountControllerV2Test {
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(requestJson()), AccountIdentityResponse.class); .put(Entity.json(requestJson()), AccountIdentityResponse.class);
verify(changeNumberManager).updatePNIKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any()); verify(changeNumberManager).updatePniKeys(eq(AuthHelper.VALID_ACCOUNT), eq("pni-identity-key"), any(), any(), any(), any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number()); assertEquals(AuthHelper.VALID_NUMBER, accountIdentityResponse.number());
@ -557,6 +557,7 @@ class AccountControllerV2Test {
"pniIdentityKey": "pni-identity-key", "pniIdentityKey": "pni-identity-key",
"deviceMessages": [], "deviceMessages": [],
"devicePniSignedPrekeys": {}, "devicePniSignedPrekeys": {},
"devicePniSignedPqPrekeys": {},
"pniRegistrationIds": {} "pniRegistrationIds": {}
} }
"""; """;

View File

@ -128,7 +128,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
accountsManager.changeNumber(account, secondNumber, null, null, null); accountsManager.changeNumber(account, secondNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -161,7 +161,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey); final Map<Long, SignedPreKey> preKeys = Map.of(Device.MASTER_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId); final Map<Long, Integer> registrationIds = Map.of(Device.MASTER_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, registrationIds); final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -191,8 +191,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
account = accountsManager.changeNumber(account, secondNumber, null, null, null); account = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
accountsManager.changeNumber(account, originalNumber, null, null, null); accountsManager.changeNumber(account, originalNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isPresent()); assertTrue(accountsManager.getByE164(originalNumber).isPresent());
assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow()); assertEquals(originalUuid, accountsManager.getByE164(originalNumber).map(Account::getUuid).orElseThrow());
@ -217,7 +217,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber, null, null, null); accountsManager.changeNumber(account, secondNumber, null, null, null, null);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());
@ -231,7 +231,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null); accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null);
final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(),
new ArrayList<>()); new ArrayList<>());
@ -251,7 +251,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null); final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
@ -262,7 +262,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));
final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null); final Account changedNumberReRegisteredAccount = accountsManager.changeNumber(reRegisteredAccount, secondNumber, null, null, null, null);
assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.of(originalUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));

View File

@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -641,7 +642,7 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber, null, null, null); account = accountsManager.changeNumber(account, targetNumber, null, null, null, null);
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
@ -656,7 +657,7 @@ class AccountsManagerTest {
final String number = "+14152222222"; final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, number, null, null, null); account = accountsManager.changeNumber(account, number, null, null, null, null);
assertEquals(number, account.getNumber()); assertEquals(number, account.getNumber());
verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any()); verify(deletedAccountsManager, never()).lockAndPut(anyString(), anyString(), any());
@ -664,13 +665,13 @@ class AccountsManagerTest {
} }
@Test @Test
void testChangePhoneNumberSameNumberWithPNIData() { void testChangePhoneNumberSameNumberWithPniData() {
final String number = "+14152222222"; final String number = "+14152222222";
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber( () -> accountsManager.changeNumber(
account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), Map.of(1L, 101)), account, number, "new-identity-key", Map.of(1L, new SignedPreKey()), null, Map.of(1L, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any()); verify(accounts, never()).update(any());
@ -694,14 +695,60 @@ class AccountsManagerTest {
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber, null, null, null); account = accountsManager.changeNumber(account, targetNumber, null, null, null, null);
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys).delete(originalPni); verify(keys).delete(originalPni);
verify(keys).delete(targetPni); verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verifyNoMoreInteractions(keys);
}
@Test
void testChangePhoneNumberWithPqKeysExistingAccount() throws InterruptedException, MismatchedDevicesException {
doAnswer(invocation -> invocation.getArgument(2, BiFunction.class).apply(Optional.empty(), Optional.empty()))
.when(deletedAccountsManager).lockAndPut(anyString(), anyString(), any());
final String originalNumber = "+14152222222";
final String targetNumber = "+14153333333";
final UUID existingAccountUuid = UUID.randomUUID();
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Map<Long, SignedPreKey> newSignedKeys = Map.of(
1L, new SignedPreKey(1L, "pub1", "sig1"),
2L, new SignedPreKey(2L, "pub2", "sig2"));
final Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
1L, new SignedPreKey(3L, "pub3", "sig3"),
2L, new SignedPreKey(4L, "pub4", "sig4"));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]);
final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds);
assertEquals(targetNumber, updatedAccount.getNumber());
assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber));
final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber);
verify(keys).delete(existingAccountUuid);
verify(keys, atLeastOnce()).delete(targetPni);
verify(keys).delete(newPni);
verify(keys).delete(originalPni);
verify(keys).getPqEnabledDevices(uuid);
verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, new SignedPreKey(3L, "pub3", "sig3"))));
verifyNoMoreInteractions(keys);
} }
@Test @Test
@ -716,7 +763,7 @@ class AccountsManagerTest {
} }
@Test @Test
void testPNIUpdate() throws MismatchedDevicesException { void testPniUpdate() throws MismatchedDevicesException {
final String number = "+14152222222"; final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
@ -730,7 +777,7 @@ class AccountsManagerTest {
UUID oldPni = account.getPhoneNumberIdentifier(); UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final Account updatedAccount = accountsManager.updatePNIKeys(account, "new-pni-identity-key", newSignedKeys, newRegistrationIds); final Account updatedAccount = accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, null, newRegistrationIds);
// non-PNI stuff should not change // non-PNI stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid()); assertEquals(oldUuid, updatedAccount.getUuid());
@ -750,7 +797,57 @@ class AccountsManagerTest {
verify(accounts).update(any()); verify(accounts).update(any());
verifyNoInteractions(deletedAccountsManager); verifyNoInteractions(deletedAccountsManager);
verifyNoInteractions(keys);
verify(keys).delete(oldPni);
}
@Test
void testPniPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[16]);
Map<Long, SignedPreKey> newSignedKeys = Map.of(
1L, new SignedPreKey(1L, "pub1", "sig1"),
2L, new SignedPreKey(2L, "pub2", "sig2"));
Map<Long, SignedPreKey> newSignedPqKeys = Map.of(
1L, new SignedPreKey(3L, "pub3", "sig3"),
2L, new SignedPreKey(4L, "pub4", "sig4"));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L));
Map<Long, SignedPreKey> oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey));
final Account updatedAccount =
accountsManager.updatePniKeys(account, "new-pni-identity-key", newSignedKeys, newSignedPqKeys, newRegistrationIds);
// non-PNI-keys stuff should not change
assertEquals(oldUuid, updatedAccount.getUuid());
assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertEquals(null, updatedAccount.getIdentityKey());
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)));
assertEquals(Map.of(1L, 101, 2L, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
assertEquals("new-pni-identity-key", updatedAccount.getPhoneNumberIdentityKey());
assertEquals(newSignedKeys,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentitySignedPreKey)));
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
verify(accounts).update(any());
verifyNoInteractions(deletedAccountsManager);
verify(keys).delete(oldPni);
// only the pq key for the already-pq-enabled device should be saved
verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
} }
@Test @Test

View File

@ -47,7 +47,7 @@ public class ChangeNumberManagerTest {
updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
when(accountsManager.changeNumber(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> { when(accountsManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class); final String number = invocation.getArgument(1, String.class);
@ -70,7 +70,7 @@ public class ChangeNumberManagerTest {
return updatedAccount; return updatedAccount;
}); });
when(accountsManager.updatePNIKeys(any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> { when(accountsManager.updatePniKeys(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>)invocation -> {
final Account account = invocation.getArgument(0, Account.class); final Account account = invocation.getArgument(0, Account.class);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
@ -94,8 +94,8 @@ public class ChangeNumberManagerTest {
void changeNumberNoMessages() throws Exception { void changeNumberNoMessages() throws Exception {
Account account = mock(Account.class); Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), eq(1L), any()); verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
} }
@ -107,8 +107,8 @@ public class ChangeNumberManagerTest {
var prekeys = Map.of(1L, new SignedPreKey()); var prekeys = Map.of(1L, new SignedPreKey());
final String pniIdentityKey = "pni-identity-key"; final String pniIdentityKey = "pni-identity-key";
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyList(), Collections.emptyMap()); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
} }
@ -139,9 +139,53 @@ public class ChangeNumberManagerTest {
when(msg.destinationDeviceId()).thenReturn(2L); when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(Device.MASTER_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
}
@Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234";
final String changedE164 = "+18025551234";
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn(originalE164);
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -174,15 +218,16 @@ public class ChangeNumberManagerTest {
final String pniIdentityKey = "pni-identity-key"; final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey()); final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L); when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, List.of(msg), registrationIds); changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -196,7 +241,7 @@ public class ChangeNumberManagerTest {
} }
@Test @Test
void updatePNIKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception { void updatePniKeysSetPrimaryDevicePrekeyAndSendMessages() throws Exception {
final UUID aci = UUID.randomUUID(); final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID(); final UUID pni = UUID.randomUUID();
@ -219,9 +264,49 @@ public class ChangeNumberManagerTest {
when(msg.destinationDeviceId()).thenReturn(2L); when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePNIKeys(account, pniIdentityKey, prekeys, List.of(msg), registrationIds); changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
verify(accountsManager).updatePNIKeys(account, pniIdentityKey, prekeys, registrationIds); verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, null, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
final MessageProtos.Envelope envelope = envelopeCaptor.getValue();
assertEquals(aci, UUID.fromString(envelope.getDestinationUuid()));
assertEquals(aci, UUID.fromString(envelope.getSourceUuid()));
assertEquals(Device.MASTER_ID, envelope.getSourceDevice());
assertFalse(updatedPhoneNumberIdentifiersByAccount.containsKey(account));
}
@Test
void updatePniKeysSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final String pniIdentityKey = "pni-identity-key";
final Map<Long, SignedPreKey> prekeys = Map.of(1L, new SignedPreKey(), 2L, new SignedPreKey());
final Map<Long, SignedPreKey> pqPrekeys = Map.of(3L, new SignedPreKey(), 4L, new SignedPreKey());
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
verify(accountsManager).updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, registrationIds);
final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class); final ArgumentCaptor<MessageProtos.Envelope> envelopeCaptor = ArgumentCaptor.forClass(MessageProtos.Envelope.class);
verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false)); verify(messageSender).sendMessage(any(), eq(d2), envelopeCaptor.capture(), eq(false));
@ -261,11 +346,11 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, messages, registrationIds)); () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", preKeys, null, messages, registrationIds));
} }
@Test @Test
void updatePNIKeysMismatchedRegistrationId() { void updatePniKeysMismatchedRegistrationId() {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
@ -291,7 +376,7 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(StaleDevicesException.class, assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePNIKeys(account, "pni-identity-key", preKeys, messages, registrationIds)); () -> changeNumberManager.updatePniKeys(account, "pni-identity-key", preKeys, null, messages, registrationIds));
} }
@Test @Test
@ -320,6 +405,6 @@ public class ChangeNumberManagerTest {
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, messages, registrationIds)); () -> changeNumberManager.changeNumber(account, "+18005559876", "pni-identity-key", null, null, messages, registrationIds));
} }
} }

View File

@ -69,7 +69,35 @@ public final class DynamoDbExtensionSchema {
.build()), .build()),
List.of(), List.of()), List.of(), List.of()),
KEYS("keys_test", EC_KEYS("keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_KEYS("pq_keys_test",
Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID,
List.of(
AttributeDefinition.builder()
.attributeName(Keys.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Keys.KEY_DEVICE_ID_KEY_ID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test",
Keys.KEY_ACCOUNT_UUID, Keys.KEY_ACCOUNT_UUID,
Keys.KEY_DEVICE_ID_KEY_ID, Keys.KEY_DEVICE_ID_KEY_ID,
List.of( List.of(

View File

@ -6,99 +6,244 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
class KeysTest { class KeysTest {
private Keys keys; private Keys keys;
@RegisterExtension @RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.KEYS); static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L; private static final long DEVICE_ID = 1L;
@BeforeEach @BeforeEach
void setup() { void setup() {
keys = new Keys(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.KEYS.tableName()); keys = new Keys(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.PQ_LAST_RESORT_KEYS.tableName());
} }
@Test @Test
void testStore() { void testStore() {
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero"); "Initial pre-key count for an account should be zero");
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Initial pre-key count for an account should be zero");
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect"); "Repeatedly storing same key should have no effect");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1, "pq-public-key", "sig")), null);
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys for the given account/device"); "Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, new SignedPreKey(1001, "pq-last-resort-key", "sig"));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID), assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")), null, null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key")), List.of(new SignedPreKey(2, "different-pq-public-key", "sig")), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(4, "fourth-public-key"), new PreKey(5, "fifth-public-key")),
List.of(new SignedPreKey(6, "sixth-pq-key", "sig"), new SignedPreKey(7, "seventh-pq-key", "sig")),
new SignedPreKey(1002, "new-last-resort-key", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device"); "Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(),
"Uploading new last-resort key should overwrite prior last-resort key for the account/device");
} }
@Test @Test
void testTakeAccountAndDeviceId() { void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.take(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = new PreKey(1, "public-key"); final PreKey preKey = new PreKey(1, "public-key");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(Optional.of(preKey), keys.take(ACCOUNT_UUID, DEVICE_ID)); assertEquals(Optional.of(preKey), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@Test
void testTakePQ() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = new SignedPreKey(1, "public-key", "sig");
final SignedPreKey preKey2 = new SignedPreKey(2, "different-public-key", "sig");
final SignedPreKey preKeyLast = new SignedPreKey(1001, "last-public-key", "sig");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
} }
@Test @Test
void testGetCount() { void testGetCount() {
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")), List.of(new SignedPreKey(1, "public-pq-key", "sig")), null);
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
} }
@Test @Test
void testDeleteByAccount() { void testDeleteByAccount() {
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID,
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID); keys.delete(ACCOUNT_UUID);
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
} }
@Test @Test
void testDeleteByAccountAndDevice() { void testDeleteByAccountAndDevice() {
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); keys.store(ACCOUNT_UUID, DEVICE_ID,
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
keys.delete(ACCOUNT_UUID, DEVICE_ID); keys.delete(ACCOUNT_UUID, DEVICE_ID);
assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent());
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1));
assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent());
}
@Test
void testStorePqLastResort() {
assertEquals(0, getLastResortCount(ACCOUNT_UUID));
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, new SignedPreKey(1L, "pub1", "sig1"), 2L, new SignedPreKey(2L, "pub2", "sig2")));
assertEquals(2, getLastResortCount(ACCOUNT_UUID));
assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId());
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId());
assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent());
keys.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, new SignedPreKey(3L, "pub3", "sig3"), 3L, new SignedPreKey(4L, "pub4", "sig4")));
assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones");
}
private int getLastResortCount(UUID uuid) {
QueryRequest queryRequest = QueryRequest.builder()
.tableName(Tables.PQ_LAST_RESORT_KEYS.tableName())
.keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid)))
.select(Select.COUNT)
.build();
QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest);
return response.count();
}
@Test
void testGetPqEnabledDevices() {
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1L, "pub1", "sig1")), null);
keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, new SignedPreKey(2L, "pub2", "sig2"));
keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(new SignedPreKey(3L, "pub3", "sig3")), new SignedPreKey(4L, "pub4", "sig4"));
keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null);
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID)));
} }
@Test @Test

View File

@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
@ -86,19 +87,25 @@ class KeysControllerTest {
private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair(); private final ECKeyPair PNI_IDENTITY_KEY_PAIR = Curve.generateKeyPair();
private final String PNI_IDENTITY_KEY = KeysHelper.serializeIdentityKey(PNI_IDENTITY_KEY_PAIR); private final String PNI_IDENTITY_KEY = KeysHelper.serializeIdentityKey(PNI_IDENTITY_KEY_PAIR);
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1"); private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3"); private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5"); private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6"); private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6");
private final PreKey SAMPLE_KEY_PNI = new PreKey(7777, "test7"); private final PreKey SAMPLE_KEY_PNI = new PreKey(7777, "test7");
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey( 1111, IDENTITY_KEY_PAIR); private final SignedPreKey SAMPLE_PQ_KEY = new SignedPreKey(2424, "test1", "sig");
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey( 2222, IDENTITY_KEY_PAIR); private final SignedPreKey SAMPLE_PQ_KEY2 = new SignedPreKey(6868, "test3", "sig");
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey( 3333, IDENTITY_KEY_PAIR); private final SignedPreKey SAMPLE_PQ_KEY3 = new SignedPreKey(1313, "test5", "sig");
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey( 4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey( 5555, PNI_IDENTITY_KEY_PAIR); private final SignedPreKey SAMPLE_PQ_KEY_PNI = new SignedPreKey(8888, "test7", "sig");
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey( 6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY = KeysHelper.signedPreKey(1111, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY2 = KeysHelper.signedPreKey(2222, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_KEY3 = KeysHelper.signedPreKey(3333, IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY = KeysHelper.signedPreKey(4444, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY2 = KeysHelper.signedPreKey(5555, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey SAMPLE_SIGNED_PNI_KEY3 = KeysHelper.signedPreKey(6666, PNI_IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedPreKey(89898, IDENTITY_KEY_PAIR); private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedPreKey(89898, IDENTITY_KEY_PAIR);
private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedPreKey(7777, PNI_IDENTITY_KEY_PAIR); private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedPreKey(7777, PNI_IDENTITY_KEY_PAIR);
@ -177,10 +184,13 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI)); when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_KEY_PNI));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY_PNI));
when(KEYS.getCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getPhoneNumberIdentitySignedPreKey()).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@ -210,8 +220,10 @@ class KeysControllerTest {
.get(PreKeyCount.class); .get(PreKeyCount.class);
assertThat(result.getCount()).isEqualTo(5); assertThat(result.getCount()).isEqualTo(5);
assertThat(result.getPqCount()).isEqualTo(5);
verify(KEYS).getCount(AuthHelper.VALID_UUID, 1); verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1);
} }
@ -223,9 +235,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class); .get(SignedPreKey.class);
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getSignature()); assertKeysMatch(VALID_DEVICE_SIGNED_KEY, result);
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_SIGNED_KEY.getPublicKey());
} }
@Test @Test
@ -237,9 +247,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(SignedPreKey.class); .get(SignedPreKey.class);
assertThat(result.getSignature()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getSignature()); assertKeysMatch(VALID_DEVICE_PNI_SIGNED_KEY, result);
assertThat(result.getKeyId()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(VALID_DEVICE_PNI_SIGNED_KEY.getPublicKey());
} }
@Test @Test
@ -291,19 +299,63 @@ class KeysControllerTest {
@Test @Test
void validSingleRequestTestV2() { void validSingleRequestTestV2() {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.<SignedPreKey>empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -317,12 +369,33 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1); verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqByPhoneNumberIdentifierTestV2() {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_PNI))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1);
assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -338,12 +411,12 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getPhoneNumberIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY_PNI.getKeyId()); assertKeysMatch(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY_PNI.getPublicKey()); assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey()); assertKeysMatch(existsAccount.getDevice(1).get().getPhoneNumberIdentitySignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_PNI, 1); verify(KEYS).takeEC(EXISTS_PNI, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -365,18 +438,20 @@ class KeysControllerTest {
@Test @Test
void testUnidentifiedRequest() { void testUnidentifiedRequest() {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID)) .target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request() .queryParam("pq", "true")
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes())) .request()
.get(PreKeyResponse.class); .header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevicesCount()).isEqualTo(1); assertThat(result.getDevicesCount()).isEqualTo(1);
assertThat(result.getDevice(1).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertKeysMatch(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertKeysMatch(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); assertKeysMatch(existsAccount.getDevice(1).get().getSignedPreKey(), result.getDevice(1).getSignedPreKey());
verify(KEYS).take(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -422,59 +497,118 @@ class KeysControllerTest {
@Test @Test
void validMultiRequestTestV2() { void validMultiRequestTestV2() {
when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.take(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2));
when(KEYS.take(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.take(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
PreKeyResponse results = resources.getJerseyTest() PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey(); PreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId(); long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId(); long deviceId = results.getDevice(1).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertKeysMatch(SAMPLE_KEY, preKey);
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY.getKeyId()); assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY.getPublicKey());
assertThat(deviceId).isEqualTo(1); assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey(); signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey(); preKey = results.getDevice(2).getPreKey();
registrationId = results.getDevice(2).getRegistrationId(); registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId(); deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId()); assertKeysMatch(SAMPLE_KEY2, preKey);
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY2.getKeyId()); assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY2.getPublicKey());
assertThat(deviceId).isEqualTo(2); assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey(); signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey(); preKey = results.getDevice(4).getPreKey();
registrationId = results.getDevice(4).getRegistrationId(); registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId(); deviceId = results.getDevice(4).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId()); assertKeysMatch(SAMPLE_KEY4, preKey);
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull(); assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4); assertThat(deviceId).isEqualTo(4);
verify(KEYS).take(EXISTS_UUID, 1); verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).take(EXISTS_UUID, 2); verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).take(EXISTS_UUID, 3); verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).take(EXISTS_UUID, 4); verifyNoMoreInteractions(KEYS);
}
@Test
void validMultiRequestPqTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_PQ_KEY));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_PQ_KEY2));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_PQ_KEY3));
when(KEYS.takePQ(EXISTS_UUID, 4)).thenReturn(Optional.<SignedPreKey>empty());
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.queryParam("pq", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKey signedPreKey = results.getDevice(1).getSignedPreKey();
PreKey preKey = results.getDevice(1).getPreKey();
SignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
assertKeysMatch(SAMPLE_KEY, preKey);
assertKeysMatch(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertKeysMatch(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
pqPreKey = results.getDevice(2).getPqPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
assertThat(preKey).isNull();
assertKeysMatch(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertKeysMatch(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
pqPreKey = results.getDevice(4).getPqPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
assertKeysMatch(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takePQ(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).takePQ(EXISTS_UUID, 4);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -523,16 +657,12 @@ class KeysControllerTest {
@Test @Test
void putKeysTestV2() { void putKeysTestV2() {
final PreKey preKey = new PreKey(31337, "foobar"); final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
List<PreKey> preKeys = new LinkedList<PreKey>() {{ PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
add(preKey);
}};
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys);
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -544,12 +674,41 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture()); verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue(); assertThat(listCaptor.getValue()).containsExactly(preKey);
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar"); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void putKeysPqTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
@ -558,13 +717,12 @@ class KeysControllerTest {
@Test @Test
void putKeysByPhoneNumberIdentifierTestV2() { void putKeysByPhoneNumberIdentifierTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair); final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair); final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
List<PreKey> preKeys = List.of(new PreKey(31337, "foobar")); PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey));
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, preKeys);
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -577,12 +735,42 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture()); verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue(); assertThat(listCaptor.getValue()).containsExactly(preKey);
assertThat(capturedList.size()).isEqualTo(1);
assertThat(capturedList.get(0).getKeyId()).isEqualTo(31337); verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
assertThat(capturedList.get(0).getPublicKey()).isEqualTo("foobar"); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
void putKeysByPhoneNumberIdentifierPqTestV2() {
final PreKey preKey = new PreKey(31337, "foobar");
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final SignedPreKey signedPreKey = KeysHelper.signedPreKey(31338, identityKeyPair);
final SignedPreKey pqPreKey = KeysHelper.signedPreKey(31339, identityKeyPair);
final SignedPreKey pqLastResortPreKey = KeysHelper.signedPreKey(31340, identityKeyPair);
final String identityKey = KeysHelper.serializeIdentityKey(identityKeyPair);
PreKeyState preKeyState = new PreKeyState(identityKey, signedPreKey, List.of(preKey), List.of(pqPreKey), pqLastResortPreKey);
Response response =
resources.getJerseyTest()
.target("/v2/keys")
.queryParam("identity", "pni")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<SignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey)); verify(AuthHelper.VALID_ACCOUNT).setPhoneNumberIdentityKey(eq(identityKey));
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey));
@ -627,7 +815,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor<List<PreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture()); verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), isNull());
List<PreKey> capturedList = listCaptor.getValue(); List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1); assertThat(capturedList.size()).isEqualTo(1);
@ -657,4 +845,13 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
} }
private void assertKeysMatch(PreKey expected, PreKey actual) {
assertThat(actual.getKeyId()).isEqualTo(expected.getKeyId());
assertThat(actual.getPublicKey()).isEqualTo(expected.getPublicKey());
if (expected instanceof final SignedPreKey signedExpected) {
final SignedPreKey signedActual = (SignedPreKey) actual;
assertThat(signedActual.getSignature()).isEqualTo(signedExpected.getSignature());
}
}
} }