Remove signed pre-keys from `Device` entities

This commit is contained in:
Jon Chambers 2023-12-08 18:43:35 -05:00 committed by Jon Chambers
parent 394f9929ad
commit b048b0bf65
14 changed files with 123 additions and 233 deletions

View File

@ -15,9 +15,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -41,6 +39,7 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@ -60,8 +59,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys") @Path("/v2/keys")
@ -69,16 +66,16 @@ import reactor.core.publisher.Mono;
public class KeysController { public class KeysController {
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final KeysManager keys; private final KeysManager keysManager;
private final AccountsManager accounts; private final AccountsManager accounts;
private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys"); private static final String GET_KEYS_COUNTER_NAME = MetricsUtil.name(KeysController.class, "getKeys");
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];
public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) { public KeysController(RateLimiters rateLimiters, KeysManager keysManager, AccountsManager accounts) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.keys = keys; this.keysManager = keysManager;
this.accounts = accounts; this.accounts = accounts;
} }
@ -92,10 +89,10 @@ public class KeysController {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture = final CompletableFuture<Integer> ecCountFuture =
keys.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); keysManager.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> pqCountFuture = final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); keysManager.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
} }
@ -124,43 +121,25 @@ public class KeysController {
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType)); checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType));
final CompletableFuture<Account> updateAccountFuture; final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
if (setKeysRequest.signedPreKey() != null && if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) {
!setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) { storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
updateAccountFuture = accounts.updateDeviceTransactionallyAsync(account,
device.getId(),
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey());
case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey());
}
},
d -> List.of(keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), setKeysRequest.signedPreKey())))
.toCompletableFuture();
} else {
updateAccountFuture = CompletableFuture.completedFuture(account);
} }
return updateAccountFuture.thenCompose(updatedAccount -> { if (setKeysRequest.signedPreKey() != null) {
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(3); storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
if (setKeysRequest.preKeys() != null && !setKeysRequest.preKeys().isEmpty()) { if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) {
storeFutures.add(keys.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys())); storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
} }
if (setKeysRequest.pqPreKeys() != null && !setKeysRequest.pqPreKeys().isEmpty()) { if (setKeysRequest.pqLastResortPreKey() != null) {
storeFutures.add(keys.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys())); storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
} }
if (setKeysRequest.pqLastResortPreKey() != null) { return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
storeFutures.add(
keys.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY));
})
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }
@ -240,28 +219,41 @@ public class KeysController {
io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId))))) io.micrometer.core.instrument.Tag.of("wildcardDeviceId", String.valueOf("*".equals(deviceId)))))
.increment(); .increment();
final List<PreKeyResponseItem> responseItems = Flux.fromIterable(parseDeviceId(deviceId, target)) final List<Device> devices = parseDeviceId(deviceId, target);
.flatMap(device -> Mono.zip( final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
Mono.just(device),
Mono.fromFuture(() -> keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())), final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> {
Mono.fromFuture(() -> keys.takeEC(targetIdentifier.uuid(), device.getId())), final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture =
Mono.fromFuture(() -> returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()) keysManager.takeEC(targetIdentifier.uuid(), device.getId());
: CompletableFuture.<Optional<KEMSignedPreKey>>completedFuture(Optional.empty()))
)).filter(keys -> keys.getT2().isPresent() || keys.getT3().isPresent() || keys.getT4().isPresent()) final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture =
.map(deviceAndKeys -> { keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId());
final Device device = deviceAndKeys.getT1();
final int registrationId = switch (targetIdentifier.identityType()) { final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
case ACI -> device.getRegistrationId(); ? keysManager.takePQ(targetIdentifier.uuid(), device.getId())
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); : CompletableFuture.completedFuture(Optional.empty());
};
return new PreKeyResponseItem(device.getId(), registrationId, return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
deviceAndKeys.getT2().orElse(null), .thenAccept(ignored -> {
deviceAndKeys.getT3().orElse(null), final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
deviceAndKeys.getT4().orElse(null)); final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null);
}).collectList() final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null);
.timeout(Duration.ofSeconds(30))
.blockOptional() if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
.orElse(Collections.emptyList()); final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
pqPreKey));
}
});
})
.toList();
CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();
final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType()); final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());
@ -289,16 +281,7 @@ public class KeysController {
final UUID identifier = auth.getAccount().getIdentifier(identityType); final UUID identifier = auth.getAccount().getIdentifier(identityType);
final byte deviceId = auth.getAuthenticatedDevice().getId(); final byte deviceId = auth.getAuthenticatedDevice().getId();
return accounts.updateDeviceTransactionallyAsync(auth.getAccount(), return keysManager.storeEcSignedPreKeys(identifier, deviceId, signedPreKey)
deviceId,
d -> {
switch (identityType) {
case ACI -> d.setSignedPreKey(signedPreKey);
case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey);
}
},
d -> List.of(keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), signedPreKey)))
.toCompletableFuture()
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }

View File

@ -12,7 +12,6 @@ import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer;
import org.signal.chat.common.EcPreKey; import org.signal.chat.common.EcPreKey;
import org.signal.chat.common.EcSignedPreKey; import org.signal.chat.common.EcSignedPreKey;
import org.signal.chat.common.KemSignedPreKey; import org.signal.chat.common.KemSignedPreKey;
@ -40,7 +39,6 @@ import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -191,18 +189,9 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
KeysGrpcService::checkEcSignedPreKey, KeysGrpcService::checkEcSignedPreKey,
(account, signedPreKey) -> { (account, signedPreKey) -> {
final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType()); final IdentityType identityType = IdentityTypeUtil.fromGrpcIdentityType(request.getIdentityType());
final Consumer<Device> deviceUpdater = switch (identityType) {
case ACI -> device -> device.setSignedPreKey(signedPreKey);
case PNI -> device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
};
final UUID identifier = account.getIdentifier(identityType); final UUID identifier = account.getIdentifier(identityType);
return Flux.merge( return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey));
Mono.fromFuture(() -> keysManager.storeEcSignedPreKeys(identifier, authenticatedDevice.deviceId(), signedPreKey)),
Mono.fromFuture(() -> accountsManager.updateDeviceAsync(account, authenticatedDevice.deviceId(), deviceUpdater)))
.then();
})); }));
} }

View File

@ -30,7 +30,6 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Queue; import java.util.Queue;
import java.util.UUID; import java.util.UUID;
@ -416,7 +415,7 @@ public class AccountsManager {
final Account numberChangedAccount = updateWithRetries( final Account numberChangedAccount = updateWithRetries(
account, account,
a -> { a -> {
setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); setPniKeys(account, pniIdentityKey, pniRegistrationIds);
return true; return true;
}, },
a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems), a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
@ -445,7 +444,7 @@ public class AccountsManager {
return redisDeleteAsync(account) return redisDeleteAsync(account)
.thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni)) .thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni))
.thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account, .thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account,
a -> setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), a -> setPniKeys(a, pniIdentityKey, pniRegistrationIds),
accounts::updateTransactionallyAsync, accounts::updateTransactionallyAsync,
() -> accounts.getByAccountIdentifierAsync(aci).thenApply(Optional::orElseThrow), () -> accounts.getByAccountIdentifierAsync(aci).thenApply(Optional::orElseThrow),
a -> keyWriteItems, a -> keyWriteItems,
@ -483,28 +482,18 @@ public class AccountsManager {
private void setPniKeys(final Account account, private void setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey, @Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) { @Nullable final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
if (ObjectUtils.allNull(pniIdentityKey, pniRegistrationIds)) {
return; return;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniRegistrationIds)) {
throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); throw new IllegalArgumentException("PNI identity key and registration IDs must be all null or all non-null");
} }
boolean changed = !Objects.equals(pniIdentityKey, account.getIdentityKey(IdentityType.PNI)); account.getDevices()
.stream()
for (Device device : account.getDevices()) { .filter(Device::isEnabled)
if (!device.isEnabled()) { .forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
continue;
}
ECSignedPreKey signedPreKey = pniSignedPreKeys.get(device.getId());
int registrationId = pniRegistrationIds.get(device.getId());
changed = changed ||
!signedPreKey.equals(device.getSignedPreKey(IdentityType.PNI)) ||
device.getRegistrationId() != registrationId;
device.setPhoneNumberIdentitySignedPreKey(signedPreKey);
device.setPhoneNumberIdentityRegistrationId(registrationId);
}
account.setPhoneNumberIdentityKey(pniIdentityKey); account.setPhoneNumberIdentityKey(pniIdentityKey);
} }

View File

@ -17,8 +17,6 @@ import java.util.stream.IntStream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter; import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
public class Device { public class Device {
@ -72,12 +70,6 @@ public class Device {
@JsonProperty("pniRegistrationId") @JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId; private Integer phoneNumberIdentityRegistrationId;
@JsonProperty
private ECSignedPreKey signedPreKey;
@JsonProperty("pniSignedPreKey")
private ECSignedPreKey phoneNumberIdentitySignedPreKey;
@JsonProperty @JsonProperty
private long lastSeen; private long lastSeen;
@ -247,25 +239,6 @@ public class Device {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId; this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
} }
/**
* @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead
*/
@Deprecated
public ECSignedPreKey getSignedPreKey(final IdentityType identityType) {
return switch (identityType) {
case ACI -> signedPreKey;
case PNI -> phoneNumberIdentitySignedPreKey;
};
}
public void setSignedPreKey(ECSignedPreKey signedPreKey) {
this.signedPreKey = signedPreKey;
}
public void setPhoneNumberIdentitySignedPreKey(final ECSignedPreKey phoneNumberIdentitySignedPreKey) {
this.phoneNumberIdentitySignedPreKey = phoneNumberIdentitySignedPreKey;
}
public long getPushTimestamp() { public long getPushTimestamp() {
return pushTimestamp; return pushTimestamp;
} }

View File

@ -38,8 +38,6 @@ public record DeviceSpec(
device.setCreated(clock.millis()); device.setCreated(clock.millis());
device.setLastSeen(Util.todayInMillis()); device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent()); device.setUserAgent(signalAgent());
device.setSignedPreKey(aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey());
apnRegistrationId().ifPresent(apnRegistrationId -> { apnRegistrationId().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId()); device.setApnId(apnRegistrationId.apnRegistrationId());

View File

@ -846,8 +846,6 @@ class AccountControllerV2Test {
device.setId(deviceData.id); device.setId(deviceData.id);
device.setAuthTokenHash(passwordTokenHash); device.setAuthTokenHash(passwordTokenHash);
device.setFetchesMessages(true); device.setFetchesMessages(true);
device.setSignedPreKey(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
device.setPhoneNumberIdentitySignedPreKey(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
device.setLastSeen(deviceData.lastSeen().toEpochMilli()); device.setLastSeen(deviceData.lastSeen().toEpochMilli());
device.setCreated(deviceData.created().toEpochMilli()); device.setCreated(deviceData.created().toEpochMilli());
device.setUserAgent(deviceData.userAgent()); device.setUserAgent(deviceData.userAgent());

View File

@ -232,8 +232,6 @@ class DeviceControllerTest {
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock); final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
assertEquals(fetchesMessages, device.getFetchesMessages()); assertEquals(fetchesMessages, device.getFetchesMessages());
expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()), expectedApnsToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getApnId()),

View File

@ -319,36 +319,33 @@ class KeysControllerTest {
@Test @Test
void putSignedPreKeyV2() { void putSignedPreKeyV2() {
ECSignedPreKey test = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR); final ECSignedPreKey signedPreKey = KeysHelper.signedECPreKey(9998, IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest()
try (final Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(test, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(signedPreKey, MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); }
verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any());
} }
@Test @Test
void putPhoneNumberIdentitySignedPreKeyV2() { void putPhoneNumberIdentitySignedPreKeyV2() {
final ECSignedPreKey replacementKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR); final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(9998, PNI_IDENTITY_KEY_PAIR);
Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v2/keys/signed") .target("/v2/keys/signed")
.queryParam("identity", "pni") .queryParam("identity", "pni")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(replacementKey, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(pniSignedPreKey, MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId(), pniSignedPreKey);
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey)); }
verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any());
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any());
} }
@Test @Test
@ -761,8 +758,7 @@ class KeysControllerTest {
assertThat(listCaptor.getValue()).containsExactly(preKey); assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
} }
@Test @Test
@ -786,9 +782,7 @@ class KeysControllerTest {
verify(KEYS, never()).storeEcOneTimePreKeys(any(), anyByte(), any()); verify(KEYS, never()).storeEcOneTimePreKeys(any(), anyByte(), any());
verify(KEYS, never()).storeKemOneTimePreKeys(any(), anyByte(), any()); verify(KEYS, never()).storeKemOneTimePreKeys(any(), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
} }
} }
@ -824,8 +818,7 @@ class KeysControllerTest {
assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
} }
@Test @Test
@ -926,8 +919,7 @@ class KeysControllerTest {
assertThat(listCaptor.getValue()).containsExactly(preKey); assertThat(listCaptor.getValue()).containsExactly(preKey);
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
} }
@Test @Test
@ -963,8 +955,7 @@ class KeysControllerTest {
assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId(), signedPreKey);
verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any());
} }
@Test @Test

View File

@ -100,7 +100,6 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicInboundMessageByteLimitConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices; import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices; import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
@ -133,7 +132,6 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -218,13 +216,13 @@ class MessageControllerTest {
final List<Device> singleDeviceList = List.of( final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()) generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, System.currentTimeMillis(), System.currentTimeMillis())
); );
final List<Device> multiDeviceList = List.of( final List<Device> multiDeviceList = List.of(
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, KeysHelper.signedECPreKey(111, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, KeysHelper.signedECPreKey(222, identityKeyPair), System.currentTimeMillis(), System.currentTimeMillis()), generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)) generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
); );
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
@ -265,12 +263,11 @@ class MessageControllerTest {
} }
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId, private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) { final long createdAt, final long lastSeen) {
final Device device = new Device(); final Device device = new Device();
device.setId(id); device.setId(id);
device.setRegistrationId(registrationId); device.setRegistrationId(registrationId);
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId); device.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
device.setSignedPreKey(signedPreKey);
device.setCreated(createdAt); device.setCreated(createdAt);
device.setLastSeen(lastSeen); device.setLastSeen(lastSeen);
device.setGcmId("isgcm"); device.setGcmId("isgcm");
@ -1045,7 +1042,7 @@ class MessageControllerTest {
IntStream.range(1, devicesPerRecipient + 1) IntStream.range(1, devicesPerRecipient + 1)
.mapToObj( .mapToObj(
d -> generateTestDevice( d -> generateTestDevice(
(byte) d, 100 + d, 10 * d, KeysHelper.signedECPreKey(333, identityKeyPair), System.currentTimeMillis(), (byte) d, 100 + d, 10 * d, System.currentTimeMillis(),
System.currentTimeMillis())) System.currentTimeMillis()))
.collect(Collectors.toList()); .collect(Collectors.toList());
final UUID aci = new UUID(0L, (long) i); final UUID aci = new UUID(0L, (long) i);

View File

@ -323,17 +323,13 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.build()) .build())
.build()); .build());
switch (identityType) { final UUID expectedIdentifier = switch (identityType) {
case IDENTITY_TYPE_ACI -> { case IDENTITY_TYPE_ACI -> AUTHENTICATED_ACI;
verify(authenticatedDevice).setSignedPreKey(signedPreKey); case IDENTITY_TYPE_PNI -> AUTHENTICATED_PNI;
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_ACI, AUTHENTICATED_DEVICE_ID, signedPreKey); default -> throw new IllegalArgumentException("Unexpected identity type");
} };
case IDENTITY_TYPE_PNI -> { verify(keysManager).storeEcSignedPreKeys(expectedIdentifier, AUTHENTICATED_DEVICE_ID, signedPreKey);
verify(authenticatedDevice).setPhoneNumberIdentitySignedPreKey(signedPreKey);
verify(keysManager).storeEcSignedPreKeys(AUTHENTICATED_PNI, AUTHENTICATED_DEVICE_ID, signedPreKey);
}
}
} }
@ParameterizedTest @ParameterizedTest

View File

@ -63,6 +63,7 @@ class AccountsManagerChangeNumberIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private KeysManager keysManager;
private ClientPresenceManager clientPresenceManager; private ClientPresenceManager clientPresenceManager;
private ExecutorService accountLockExecutor; private ExecutorService accountLockExecutor;
private ExecutorService clientPresenceExecutor; private ExecutorService clientPresenceExecutor;
@ -79,7 +80,7 @@ class AccountsManagerChangeNumberIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final KeysManager keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(), Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(), Tables.PQ_KEYS.tableName(),
@ -189,7 +190,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes);
account.getPrimaryDevice().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair)); keysManager.storeEcSignedPreKeys(account.getIdentifier(IdentityType.ACI),
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair)).join();
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
@ -216,7 +218,8 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(OptionalInt.of(rotatedPniRegistrationId), assertEquals(OptionalInt.of(rotatedPniRegistrationId),
updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId()); updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
assertEquals(rotatedSignedPreKey, updatedAccount.getPrimaryDevice().getSignedPreKey(IdentityType.PNI)); assertEquals(Optional.of(rotatedSignedPreKey),
keysManager.getEcSignedPreKey(updatedAccount.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
} }
@Test @Test

View File

@ -18,6 +18,7 @@ import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
@ -48,7 +49,6 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -670,7 +670,6 @@ class AccountsManagerTest {
Device enabledDevice = new Device(); Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true); enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis()); enabledDevice.setLastSeen(System.currentTimeMillis());
final byte deviceId = account.getNextDeviceId(); final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId); enabledDevice.setId(deviceId);
@ -703,7 +702,6 @@ class AccountsManagerTest {
Device enabledDevice = new Device(); Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true); enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis()); enabledDevice.setLastSeen(System.currentTimeMillis());
final byte deviceId = account.getNextDeviceId(); final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId); enabledDevice.setId(deviceId);
@ -747,6 +745,7 @@ class AccountsManagerTest {
assertFalse(account.getDevice(linkedDevice.getId()).isPresent()); assertFalse(account.getDevice(linkedDevice.getId()).isPresent());
verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId()); verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId());
verify(keysManager, times(2)).deleteSingleUsePreKeys(account.getUuid(), linkedDevice.getId()); verify(keysManager, times(2)).deleteSingleUsePreKeys(account.getUuid(), linkedDevice.getId());
verify(keysManager).buildWriteItemsForRemovedDevice(account.getUuid(), account.getPhoneNumberIdentifier(), linkedDevice.getId());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId()); verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId());
} }
@ -775,9 +774,17 @@ class AccountsManagerTest {
final String e164 = "+18005550123"; final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null); final AccountAttributes attributes = new AccountAttributes(false, 1, 2, null, null, true, null);
createAccount(e164, attributes); final Account createdAccount = createAccount(e164, attributes);
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any()); verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any());
verify(keysManager).buildWriteItemsForNewDevice(
eq(createdAccount.getUuid()),
eq(createdAccount.getPhoneNumberIdentifier()),
eq(Device.PRIMARY_ID),
notNull(),
notNull(),
notNull(),
notNull());
verifyNoInteractions(messagesManager); verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager); verifyNoInteractions(profilesManager);
@ -806,13 +813,22 @@ class AccountsManagerTest {
when(accounts.reclaimAccount(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.reclaimAccount(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
createAccount(e164, attributes); final Account reregisteredAccount = createAccount(e164, attributes);
assertTrue(phoneNumberIdentifiersByE164.containsKey(e164)); assertTrue(phoneNumberIdentifiersByE164.containsKey(e164));
verify(accounts) verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any()); .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any());
verify(keysManager).buildWriteItemsForNewDevice(
eq(reregisteredAccount.getUuid()),
eq(reregisteredAccount.getPhoneNumberIdentifier()),
eq(Device.PRIMARY_ID),
notNull(),
notNull(),
notNull(),
notNull());
verify(keysManager, times(2)).deleteSingleUsePreKeys(existingUuid); verify(keysManager, times(2)).deleteSingleUsePreKeys(existingUuid);
verify(keysManager, times(2)).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(e164)); verify(keysManager, times(2)).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(e164));
verify(messagesManager, times(2)).clear(existingUuid); verify(messagesManager, times(2)).clear(existingUuid);
@ -943,8 +959,6 @@ class AccountsManagerTest {
assertNull(device.getApnId()); assertNull(device.getApnId());
assertNull(device.getVoipApnId()); assertNull(device.getVoipApnId());
assertNull(device.getGcmId()); assertNull(device.getGcmId());
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
} }
@ParameterizedTest @ParameterizedTest
@ -1142,9 +1156,6 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1154,8 +1165,6 @@ class AccountsManagerTest {
UUID oldUuid = account.getUuid(); UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier(); UUID oldPni = account.getPhoneNumberIdentifier();
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1169,16 +1178,11 @@ class AccountsManagerTest {
assertEquals(number, updatedAccount.getNumber()); assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI stuff should // PNI stuff should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newSignedKeys,
updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.PNI))));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
@ -1198,9 +1202,6 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1219,9 +1220,6 @@ class AccountsManagerTest {
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount = final Account updatedAccount =
@ -1232,16 +1230,11 @@ class AccountsManagerTest {
assertEquals(number, updatedAccount.getNumber()); assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should // PNI keys should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newSignedKeys,
updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.PNI))));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
@ -1262,9 +1255,6 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1282,9 +1272,6 @@ class AccountsManagerTest {
when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Account updatedAccount = final Account updatedAccount =
@ -1295,16 +1282,11 @@ class AccountsManagerTest {
assertEquals(number, updatedAccount.getNumber()); assertEquals(number, updatedAccount.getNumber());
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier()); assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should // PNI keys should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newSignedKeys,
updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.PNI))));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt())));
@ -1524,8 +1506,6 @@ class AccountsManagerTest {
final Device parsedDevice = parsedAccount.getPrimaryDevice(); final Device parsedDevice = parsedAccount.getPrimaryDevice();
assertEquals(originalDevice.getId(), parsedDevice.getId()); assertEquals(originalDevice.getId(), parsedDevice.getId());
assertEquals(originalDevice.getSignedPreKey(IdentityType.ACI), parsedDevice.getSignedPreKey(IdentityType.ACI));
assertEquals(originalDevice.getSignedPreKey(IdentityType.PNI), parsedDevice.getSignedPreKey(IdentityType.PNI));
assertEquals(originalDevice.getRegistrationId(), parsedDevice.getRegistrationId()); assertEquals(originalDevice.getRegistrationId(), parsedDevice.getRegistrationId());
assertEquals(originalDevice.getPhoneNumberIdentityRegistrationId(), assertEquals(originalDevice.getPhoneNumberIdentityRegistrationId(),
parsedDevice.getPhoneNumberIdentityRegistrationId()); parsedDevice.getPhoneNumberIdentityRegistrationId());
@ -1541,7 +1521,6 @@ class AccountsManagerTest {
final Device device = new Device(); final Device device = new Device();
device.setId(Device.PRIMARY_ID); device.setId(Device.PRIMARY_ID);
device.setFetchesMessages(true); device.setFetchesMessages(true);
device.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
device.setLastSeen(lastSeen); device.setLastSeen(lastSeen);
return device; return device;

View File

@ -52,7 +52,6 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@ -1372,8 +1371,6 @@ class AccountsTest {
assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId()); assertThat(resultDevice.getApnId()).isEqualTo(expectingDevice.getApnId());
assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId()); assertThat(resultDevice.getGcmId()).isEqualTo(expectingDevice.getGcmId());
assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen()); assertThat(resultDevice.getLastSeen()).isEqualTo(expectingDevice.getLastSeen());
assertThat(resultDevice.getSignedPreKey(IdentityType.ACI)).isEqualTo(
expectingDevice.getSignedPreKey(IdentityType.ACI));
assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages()); assertThat(resultDevice.getFetchesMessages()).isEqualTo(expectingDevice.getFetchesMessages());
assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent()); assertThat(resultDevice.getUserAgent()).isEqualTo(expectingDevice.getUserAgent());
assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName()); assertThat(resultDevice.getName()).isEqualTo(expectingDevice.getName());

View File

@ -47,7 +47,6 @@ public class DevicesHelper {
public static void setEnabled(Device device, boolean enabled) { public static void setEnabled(Device device, boolean enabled) {
if (enabled) { if (enabled) {
device.setPhoneNumberIdentitySignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));
device.setGcmId("testGcmId" + RANDOM.nextLong()); device.setGcmId("testGcmId" + RANDOM.nextLong());
device.setLastSeen(Util.todayInMillis()); device.setLastSeen(Util.todayInMillis());
} else { } else {