Make `getRegistrationId` identity-type-aware

This commit is contained in:
Jon Chambers 2025-05-07 11:01:17 -04:00 committed by Jon Chambers
parent 13fc0ffbca
commit 9ec66dac7f
12 changed files with 45 additions and 57 deletions

View File

@ -381,10 +381,7 @@ public class KeysController {
.increment();
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId();
};
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,

View File

@ -24,7 +24,6 @@ import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
import org.signal.libsignal.protocol.util.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
@ -36,7 +35,6 @@ import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Util;
@ -318,11 +316,7 @@ public class MessageSender {
// We know the device must be present because we've already filtered out device IDs that aren't associated
// with the given account
final Device device = account.getDevice(deviceId).orElseThrow();
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId();
};
final int expectedRegistrationId = device.getRegistrationId(serviceIdentifier.identityType());
return registrationId != expectedRegistrationId;
})

View File

@ -61,10 +61,8 @@ public class ReceiptSender {
.collect(Collectors.toMap(Device::getId, ignored -> message));
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId();
}));
.collect(Collectors.toMap(Device::getId,
device -> device.getRegistrationId(destinationIdentifier.identityType())));
try {
messageSender.sendMessages(destinationAccount,

View File

@ -20,6 +20,7 @@ import java.util.stream.IntStream;
import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
@ -207,18 +208,17 @@ public class Device {
return getId() == PRIMARY_ID;
}
public int getRegistrationId() {
return registrationId;
public int getRegistrationId(final IdentityType identityType) {
return switch (identityType) {
case ACI -> registrationId;
case PNI -> phoneNumberIdentityRegistrationId;
};
}
public void setRegistrationId(int registrationId) {
this.registrationId = registrationId;
}
public int getPhoneNumberIdentityRegistrationId() {
return phoneNumberIdentityRegistrationId;
}
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}

View File

@ -214,11 +214,11 @@ class KeysControllerTest {
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
when(sampleDevice.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);

View File

@ -1244,15 +1244,15 @@ class MessageControllerTest {
final Device singleDeviceAccountPrimary = mock(Device.class);
when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId);
when(singleDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(singleDevicePrimaryRegistrationId);
final Device multiDeviceAccountPrimary = mock(Device.class);
when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId);
when(multiDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(multiDevicePrimaryRegistrationId);
final Device multiDeviceAccountLinked = mock(Device.class);
when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId);
when(multiDeviceAccountLinked.getRegistrationId(IdentityType.ACI)).thenReturn(multiDeviceLinkedRegistrationId);
final Account singleDeviceAccount = mock(Account.class);
when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci);
@ -1662,15 +1662,15 @@ class MessageControllerTest {
final Device singleDeviceAccountPrimary = mock(Device.class);
when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId);
when(singleDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(singleDevicePrimaryRegistrationId);
final Device multiDeviceAccountPrimary = mock(Device.class);
when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId);
when(multiDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(multiDevicePrimaryRegistrationId);
final Device multiDeviceAccountLinked = mock(Device.class);
when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId);
when(multiDeviceAccountLinked.getRegistrationId(IdentityType.ACI)).thenReturn(multiDeviceLinkedRegistrationId);
final Account singleDeviceAccount = mock(Account.class);
when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci);

View File

@ -125,13 +125,13 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest<MessagesGrpcService, Me
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty()));
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
when(authenticatedDevice.getRegistrationId()).thenReturn(AUTHENTICATED_REGISTRATION_ID);
when(authenticatedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(AUTHENTICATED_REGISTRATION_ID);
when(linkedDevice.getId()).thenReturn(LINKED_DEVICE_ID);
when(linkedDevice.getRegistrationId()).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
when(secondLinkedDevice.getId()).thenReturn(SECOND_LINKED_DEVICE_ID);
when(secondLinkedDevice.getRegistrationId()).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
when(secondLinkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);

View File

@ -92,7 +92,7 @@ class MessageSenderTest {
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
if (hasPushToken) {
when(device.getApnId()).thenReturn("apns-token");
@ -140,7 +140,7 @@ class MessageSenderTest {
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final MismatchedDevicesException mismatchedDevicesException =
@ -178,7 +178,7 @@ class MessageSenderTest {
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
if (hasPushToken) {
@ -230,7 +230,7 @@ class MessageSenderTest {
when(account.getDevices()).thenReturn(List.of(device));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
when(device.getRegistrationId()).thenReturn(registrationId);
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
when(device.getApnId()).thenReturn("apns-token");
final SealedSenderMultiRecipientMessage multiRecipientMessage =
@ -291,13 +291,13 @@ class MessageSenderTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(primaryDevicePniRegistrationId);
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryDeviceAciRegistrationId);
when(primaryDevice.getRegistrationId(IdentityType.PNI)).thenReturn(primaryDevicePniRegistrationId);
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(linkedDevicePniRegistrationId);
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceAciRegistrationId);
when(linkedDevice.getRegistrationId(IdentityType.PNI)).thenReturn(linkedDevicePniRegistrationId);
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));

View File

@ -506,8 +506,8 @@ public class AccountCreationDeletionIntegrationTest {
assertEquals(number, account.getNumber());
assertEquals(signalAgent, primaryDevice.getUserAgent());
assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages());
assertEquals(registrationId, primaryDevice.getRegistrationId());
assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId());
assertEquals(registrationId, primaryDevice.getRegistrationId(IdentityType.ACI));
assertEquals(pniRegistrationId, primaryDevice.getRegistrationId(IdentityType.PNI));
assertArrayEquals(deviceName, primaryDevice.getName());
assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber());
assertEquals(deviceCapabilities, primaryDevice.getCapabilities());

View File

@ -17,7 +17,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
@ -232,7 +231,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(rotatedPniRegistrationId, updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
assertEquals(rotatedPniRegistrationId, updatedAccount.getPrimaryDevice().getRegistrationId(IdentityType.PNI));
assertEquals(Optional.of(rotatedSignedPreKey),
keysManager.getEcSignedPreKey(updatedAccount.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());

View File

@ -1012,8 +1012,8 @@ class AccountsManagerTest {
assertTrue(device.getAuthTokenHash().verify(password));
assertEquals(signalAgent, device.getUserAgent());
assertEquals(Collections.emptySet(), device.getCapabilities());
assertEquals(aciRegistrationId, device.getRegistrationId());
assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId());
assertEquals(aciRegistrationId, device.getRegistrationId(IdentityType.ACI));
assertEquals(pniRegistrationId, device.getRegistrationId(IdentityType.PNI));
assertTrue(device.getFetchesMessages());
assertNull(device.getApnId());
assertNull(device.getGcmId());
@ -1265,12 +1265,12 @@ class AccountsManagerTest {
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, device -> device.getRegistrationId(IdentityType.ACI))));
// PNI stuff should
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(newRegistrationIds,
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentityRegistrationId)));
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, device -> device.getRegistrationId(IdentityType.PNI))));
verify(accounts).updateTransactionallyAsync(any(), any());
@ -1509,9 +1509,8 @@ class AccountsManagerTest {
final Device parsedDevice = parsedAccount.getPrimaryDevice();
assertEquals(originalDevice.getId(), parsedDevice.getId());
assertEquals(originalDevice.getRegistrationId(), parsedDevice.getRegistrationId());
assertEquals(originalDevice.getPhoneNumberIdentityRegistrationId(),
parsedDevice.getPhoneNumberIdentityRegistrationId());
assertEquals(originalDevice.getRegistrationId(IdentityType.ACI), parsedDevice.getRegistrationId(IdentityType.ACI));
assertEquals(originalDevice.getRegistrationId(IdentityType.PNI), parsedDevice.getRegistrationId(IdentityType.PNI));
assertEquals(originalDevice.getCapabilities(), parsedDevice.getCapabilities());
assertEquals(originalDevice.getFetchesMessages(), parsedDevice.getFetchesMessages());
}

View File

@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
@ -141,13 +142,13 @@ public class ChangeNumberManagerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(primaryDevice.getRegistrationId()).thenReturn(7);
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(7);
final Device linkedDevice = mock(Device.class);
final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
@ -401,7 +402,7 @@ public class ChangeNumberManagerTest {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
when(device.getRegistrationId(IdentityType.ACI)).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));