Assume that PNI registration IDs are always present on `Device` records

This commit is contained in:
Jon Chambers 2025-05-02 12:28:55 -04:00 committed by Jon Chambers
parent 93ba6616d1
commit 13fc0ffbca
9 changed files with 13 additions and 42 deletions

View File

@ -383,7 +383,7 @@ public class KeysController {
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) { if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) { final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId(); case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); case PNI -> device.getPhoneNumberIdentityRegistrationId();
}; };
responseItems.add( responseItems.add(

View File

@ -321,7 +321,7 @@ public class MessageSender {
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) { final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId(); case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId); case PNI -> device.getPhoneNumberIdentityRegistrationId();
}; };
return registrationId != expectedRegistrationId; return registrationId != expectedRegistrationId;

View File

@ -63,7 +63,7 @@ public class ReceiptSender {
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream() final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) { .collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId(); case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElseGet(device::getRegistrationId); case PNI -> device.getPhoneNumberIdentityRegistrationId();
})); }));
try { try {

View File

@ -64,9 +64,8 @@ public class Device {
@JsonProperty @JsonProperty
private int registrationId; private int registrationId;
@Nullable
@JsonProperty("pniRegistrationId") @JsonProperty("pniRegistrationId")
private Integer phoneNumberIdentityRegistrationId; private int phoneNumberIdentityRegistrationId;
@JsonProperty @JsonProperty
private long lastSeen; private long lastSeen;
@ -216,8 +215,8 @@ public class Device {
this.registrationId = registrationId; this.registrationId = registrationId;
} }
public OptionalInt getPhoneNumberIdentityRegistrationId() { public int getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId != null ? OptionalInt.of(phoneNumberIdentityRegistrationId) : OptionalInt.empty(); return phoneNumberIdentityRegistrationId;
} }
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) { public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {

View File

@ -38,7 +38,6 @@ import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream; import java.util.stream.IntStream;
@ -219,7 +218,7 @@ class KeysControllerTest {
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4); when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(SAMPLE_PNI_REGISTRATION_ID)); when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
when(sampleDevice.getId()).thenReturn(sampleDeviceId); when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
@ -437,29 +436,6 @@ class KeysControllerTest {
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@Test
void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() {
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.empty());
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/PNI:%s/1", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class);
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@Test @Test
void testGetKeysRateLimited() throws RateLimitExceededException { void testGetKeysRateLimited() throws RateLimitExceededException {
Duration retryAfter = Duration.ofSeconds(31); Duration retryAfter = Duration.ofSeconds(31);

View File

@ -12,7 +12,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -23,7 +22,6 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -294,12 +292,12 @@ class MessageSenderTest {
final Device primaryDevice = mock(Device.class); final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId); when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId); when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(primaryDevicePniRegistrationId)); when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(primaryDevicePniRegistrationId);
final Device linkedDevice = mock(Device.class); final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId); when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId); when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(linkedDevicePniRegistrationId)); when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(linkedDevicePniRegistrationId);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));

View File

@ -507,7 +507,7 @@ public class AccountCreationDeletionIntegrationTest {
assertEquals(signalAgent, primaryDevice.getUserAgent()); assertEquals(signalAgent, primaryDevice.getUserAgent());
assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages()); assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages());
assertEquals(registrationId, primaryDevice.getRegistrationId()); assertEquals(registrationId, primaryDevice.getRegistrationId());
assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId().orElseThrow()); assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId());
assertArrayEquals(deviceName, primaryDevice.getName()); assertArrayEquals(deviceName, primaryDevice.getName());
assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()); assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber());
assertEquals(deviceCapabilities, primaryDevice.getCapabilities()); assertEquals(deviceCapabilities, primaryDevice.getCapabilities());

View File

@ -232,9 +232,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni)); assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(rotatedPniRegistrationId, updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
assertEquals(OptionalInt.of(rotatedPniRegistrationId),
updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
assertEquals(Optional.of(rotatedSignedPreKey), assertEquals(Optional.of(rotatedSignedPreKey),
keysManager.getEcSignedPreKey(updatedAccount.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join()); keysManager.getEcSignedPreKey(updatedAccount.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());

View File

@ -1013,7 +1013,7 @@ class AccountsManagerTest {
assertEquals(signalAgent, device.getUserAgent()); assertEquals(signalAgent, device.getUserAgent());
assertEquals(Collections.emptySet(), device.getCapabilities()); assertEquals(Collections.emptySet(), device.getCapabilities());
assertEquals(aciRegistrationId, device.getRegistrationId()); assertEquals(aciRegistrationId, device.getRegistrationId());
assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId().getAsInt()); assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId());
assertTrue(device.getFetchesMessages()); assertTrue(device.getFetchesMessages());
assertNull(device.getApnId()); assertNull(device.getApnId());
assertNull(device.getGcmId()); assertNull(device.getGcmId());
@ -1270,7 +1270,7 @@ class AccountsManagerTest {
// PNI stuff should // PNI stuff should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI)); assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newRegistrationIds, assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentityRegistrationId)));
verify(accounts).updateTransactionallyAsync(any(), any()); verify(accounts).updateTransactionallyAsync(any(), any());