Make `getRegistrationId` identity-type-aware
This commit is contained in:
parent
13fc0ffbca
commit
9ec66dac7f
|
@ -381,10 +381,7 @@ public class KeysController {
|
|||
.increment();
|
||||
|
||||
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
|
||||
final int registrationId = switch (targetIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId();
|
||||
};
|
||||
final int registrationId = device.getRegistrationId(targetIdentifier.identityType());
|
||||
|
||||
responseItems.add(
|
||||
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
|
||||
|
|
|
@ -24,7 +24,6 @@ import javax.annotation.Nullable;
|
|||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.signal.libsignal.protocol.SealedSenderMultiRecipientMessage;
|
||||
import org.signal.libsignal.protocol.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.controllers.MessageController;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevices;
|
||||
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
|
||||
|
@ -36,7 +35,6 @@ import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
|||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
|
@ -318,11 +316,7 @@ public class MessageSender {
|
|||
// We know the device must be present because we've already filtered out device IDs that aren't associated
|
||||
// with the given account
|
||||
final Device device = account.getDevice(deviceId).orElseThrow();
|
||||
|
||||
final int expectedRegistrationId = switch (serviceIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId();
|
||||
};
|
||||
final int expectedRegistrationId = device.getRegistrationId(serviceIdentifier.identityType());
|
||||
|
||||
return registrationId != expectedRegistrationId;
|
||||
})
|
||||
|
|
|
@ -61,10 +61,8 @@ public class ReceiptSender {
|
|||
.collect(Collectors.toMap(Device::getId, ignored -> message));
|
||||
|
||||
final Map<Byte, Integer> registrationIdsByDeviceId = destinationAccount.getDevices().stream()
|
||||
.collect(Collectors.toMap(Device::getId, device -> switch (destinationIdentifier.identityType()) {
|
||||
case ACI -> device.getRegistrationId();
|
||||
case PNI -> device.getPhoneNumberIdentityRegistrationId();
|
||||
}));
|
||||
.collect(Collectors.toMap(Device::getId,
|
||||
device -> device.getRegistrationId(destinationIdentifier.identityType())));
|
||||
|
||||
try {
|
||||
messageSender.sendMessages(destinationAccount,
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.stream.IntStream;
|
|||
import javax.annotation.Nullable;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
|
||||
import org.whispersystems.textsecuregcm.util.DeviceNameByteArrayAdapter;
|
||||
|
||||
|
@ -207,18 +208,17 @@ public class Device {
|
|||
return getId() == PRIMARY_ID;
|
||||
}
|
||||
|
||||
public int getRegistrationId() {
|
||||
return registrationId;
|
||||
public int getRegistrationId(final IdentityType identityType) {
|
||||
return switch (identityType) {
|
||||
case ACI -> registrationId;
|
||||
case PNI -> phoneNumberIdentityRegistrationId;
|
||||
};
|
||||
}
|
||||
|
||||
public void setRegistrationId(int registrationId) {
|
||||
this.registrationId = registrationId;
|
||||
}
|
||||
|
||||
public int getPhoneNumberIdentityRegistrationId() {
|
||||
return phoneNumberIdentityRegistrationId;
|
||||
}
|
||||
|
||||
public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRegistrationId) {
|
||||
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
|
||||
}
|
||||
|
|
|
@ -214,11 +214,11 @@ class KeysControllerTest {
|
|||
|
||||
AccountsHelper.setupMockUpdate(accounts);
|
||||
|
||||
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
|
||||
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
|
||||
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
|
||||
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
|
||||
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
|
||||
when(sampleDevice.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID);
|
||||
when(sampleDevice2.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2);
|
||||
when(sampleDevice3.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID2);
|
||||
when(sampleDevice4.getRegistrationId(IdentityType.ACI)).thenReturn(SAMPLE_REGISTRATION_ID4);
|
||||
when(sampleDevice.getRegistrationId(IdentityType.PNI)).thenReturn(SAMPLE_PNI_REGISTRATION_ID);
|
||||
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
|
||||
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
|
||||
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
|
||||
|
|
|
@ -1244,15 +1244,15 @@ class MessageControllerTest {
|
|||
|
||||
final Device singleDeviceAccountPrimary = mock(Device.class);
|
||||
when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId);
|
||||
when(singleDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(singleDevicePrimaryRegistrationId);
|
||||
|
||||
final Device multiDeviceAccountPrimary = mock(Device.class);
|
||||
when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId);
|
||||
when(multiDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(multiDevicePrimaryRegistrationId);
|
||||
|
||||
final Device multiDeviceAccountLinked = mock(Device.class);
|
||||
when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
|
||||
when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId);
|
||||
when(multiDeviceAccountLinked.getRegistrationId(IdentityType.ACI)).thenReturn(multiDeviceLinkedRegistrationId);
|
||||
|
||||
final Account singleDeviceAccount = mock(Account.class);
|
||||
when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci);
|
||||
|
@ -1662,15 +1662,15 @@ class MessageControllerTest {
|
|||
|
||||
final Device singleDeviceAccountPrimary = mock(Device.class);
|
||||
when(singleDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(singleDeviceAccountPrimary.getRegistrationId()).thenReturn(singleDevicePrimaryRegistrationId);
|
||||
when(singleDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(singleDevicePrimaryRegistrationId);
|
||||
|
||||
final Device multiDeviceAccountPrimary = mock(Device.class);
|
||||
when(multiDeviceAccountPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(multiDeviceAccountPrimary.getRegistrationId()).thenReturn(multiDevicePrimaryRegistrationId);
|
||||
when(multiDeviceAccountPrimary.getRegistrationId(IdentityType.ACI)).thenReturn(multiDevicePrimaryRegistrationId);
|
||||
|
||||
final Device multiDeviceAccountLinked = mock(Device.class);
|
||||
when(multiDeviceAccountLinked.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
|
||||
when(multiDeviceAccountLinked.getRegistrationId()).thenReturn(multiDeviceLinkedRegistrationId);
|
||||
when(multiDeviceAccountLinked.getRegistrationId(IdentityType.ACI)).thenReturn(multiDeviceLinkedRegistrationId);
|
||||
|
||||
final Account singleDeviceAccount = mock(Account.class);
|
||||
when(singleDeviceAccount.getIdentifier(IdentityType.ACI)).thenReturn(singleDeviceAccountAci);
|
||||
|
|
|
@ -125,13 +125,13 @@ class MessagesGrpcServiceTest extends SimpleBaseGrpcTest<MessagesGrpcService, Me
|
|||
.thenReturn(new SpamCheckResult<>(Optional.empty(), Optional.empty()));
|
||||
|
||||
when(authenticatedDevice.getId()).thenReturn(AUTHENTICATED_DEVICE_ID);
|
||||
when(authenticatedDevice.getRegistrationId()).thenReturn(AUTHENTICATED_REGISTRATION_ID);
|
||||
when(authenticatedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(AUTHENTICATED_REGISTRATION_ID);
|
||||
|
||||
when(linkedDevice.getId()).thenReturn(LINKED_DEVICE_ID);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
|
||||
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(LINKED_DEVICE_REGISTRATION_ID);
|
||||
|
||||
when(secondLinkedDevice.getId()).thenReturn(SECOND_LINKED_DEVICE_ID);
|
||||
when(secondLinkedDevice.getRegistrationId()).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
|
||||
when(secondLinkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(SECOND_LINKED_DEVICE_REGISTRATION_ID);
|
||||
|
||||
when(authenticatedAccount.getUuid()).thenReturn(AUTHENTICATED_ACI);
|
||||
when(authenticatedAccount.getIdentifier(IdentityType.ACI)).thenReturn(AUTHENTICATED_ACI);
|
||||
|
|
|
@ -92,7 +92,7 @@ class MessageSenderTest {
|
|||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
|
||||
|
||||
if (hasPushToken) {
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
@ -140,7 +140,7 @@ class MessageSenderTest {
|
|||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final MismatchedDevicesException mismatchedDevicesException =
|
||||
|
@ -178,7 +178,7 @@ class MessageSenderTest {
|
|||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
if (hasPushToken) {
|
||||
|
@ -230,7 +230,7 @@ class MessageSenderTest {
|
|||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
when(device.getRegistrationId()).thenReturn(registrationId);
|
||||
when(device.getRegistrationId(IdentityType.ACI)).thenReturn(registrationId);
|
||||
when(device.getApnId()).thenReturn("apns-token");
|
||||
|
||||
final SealedSenderMultiRecipientMessage multiRecipientMessage =
|
||||
|
@ -291,13 +291,13 @@ class MessageSenderTest {
|
|||
|
||||
final Device primaryDevice = mock(Device.class);
|
||||
when(primaryDevice.getId()).thenReturn(primaryDeviceId);
|
||||
when(primaryDevice.getRegistrationId()).thenReturn(primaryDeviceAciRegistrationId);
|
||||
when(primaryDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(primaryDevicePniRegistrationId);
|
||||
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(primaryDeviceAciRegistrationId);
|
||||
when(primaryDevice.getRegistrationId(IdentityType.PNI)).thenReturn(primaryDevicePniRegistrationId);
|
||||
|
||||
final Device linkedDevice = mock(Device.class);
|
||||
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceAciRegistrationId);
|
||||
when(linkedDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(linkedDevicePniRegistrationId);
|
||||
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceAciRegistrationId);
|
||||
when(linkedDevice.getRegistrationId(IdentityType.PNI)).thenReturn(linkedDevicePniRegistrationId);
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
|
||||
|
|
|
@ -506,8 +506,8 @@ public class AccountCreationDeletionIntegrationTest {
|
|||
assertEquals(number, account.getNumber());
|
||||
assertEquals(signalAgent, primaryDevice.getUserAgent());
|
||||
assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages());
|
||||
assertEquals(registrationId, primaryDevice.getRegistrationId());
|
||||
assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId());
|
||||
assertEquals(registrationId, primaryDevice.getRegistrationId(IdentityType.ACI));
|
||||
assertEquals(pniRegistrationId, primaryDevice.getRegistrationId(IdentityType.PNI));
|
||||
assertArrayEquals(deviceName, primaryDevice.getName());
|
||||
assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber());
|
||||
assertEquals(deviceCapabilities, primaryDevice.getCapabilities());
|
||||
|
|
|
@ -17,7 +17,6 @@ import java.nio.charset.StandardCharsets;
|
|||
import java.time.Clock;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.OptionalInt;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
@ -232,7 +231,7 @@ class AccountsManagerChangeNumberIntegrationTest {
|
|||
assertEquals(Optional.empty(), accountsManager.findRecentlyDeletedAccountIdentifier(secondPni));
|
||||
|
||||
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
|
||||
assertEquals(rotatedPniRegistrationId, updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
|
||||
assertEquals(rotatedPniRegistrationId, updatedAccount.getPrimaryDevice().getRegistrationId(IdentityType.PNI));
|
||||
|
||||
assertEquals(Optional.of(rotatedSignedPreKey),
|
||||
keysManager.getEcSignedPreKey(updatedAccount.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
|
||||
|
|
|
@ -1012,8 +1012,8 @@ class AccountsManagerTest {
|
|||
assertTrue(device.getAuthTokenHash().verify(password));
|
||||
assertEquals(signalAgent, device.getUserAgent());
|
||||
assertEquals(Collections.emptySet(), device.getCapabilities());
|
||||
assertEquals(aciRegistrationId, device.getRegistrationId());
|
||||
assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId());
|
||||
assertEquals(aciRegistrationId, device.getRegistrationId(IdentityType.ACI));
|
||||
assertEquals(pniRegistrationId, device.getRegistrationId(IdentityType.PNI));
|
||||
assertTrue(device.getFetchesMessages());
|
||||
assertNull(device.getApnId());
|
||||
assertNull(device.getGcmId());
|
||||
|
@ -1265,12 +1265,12 @@ class AccountsManagerTest {
|
|||
assertEquals(oldPni, updatedAccount.getPhoneNumberIdentifier());
|
||||
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
|
||||
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
|
||||
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
|
||||
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, device -> device.getRegistrationId(IdentityType.ACI))));
|
||||
|
||||
// PNI stuff should
|
||||
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
|
||||
assertEquals(newRegistrationIds,
|
||||
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getPhoneNumberIdentityRegistrationId)));
|
||||
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, device -> device.getRegistrationId(IdentityType.PNI))));
|
||||
|
||||
verify(accounts).updateTransactionallyAsync(any(), any());
|
||||
|
||||
|
@ -1509,9 +1509,8 @@ class AccountsManagerTest {
|
|||
final Device parsedDevice = parsedAccount.getPrimaryDevice();
|
||||
|
||||
assertEquals(originalDevice.getId(), parsedDevice.getId());
|
||||
assertEquals(originalDevice.getRegistrationId(), parsedDevice.getRegistrationId());
|
||||
assertEquals(originalDevice.getPhoneNumberIdentityRegistrationId(),
|
||||
parsedDevice.getPhoneNumberIdentityRegistrationId());
|
||||
assertEquals(originalDevice.getRegistrationId(IdentityType.ACI), parsedDevice.getRegistrationId(IdentityType.ACI));
|
||||
assertEquals(originalDevice.getRegistrationId(IdentityType.PNI), parsedDevice.getRegistrationId(IdentityType.PNI));
|
||||
assertEquals(originalDevice.getCapabilities(), parsedDevice.getCapabilities());
|
||||
assertEquals(originalDevice.getFetchesMessages(), parsedDevice.getFetchesMessages());
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
|
|||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
|
||||
import org.whispersystems.textsecuregcm.identity.IdentityType;
|
||||
import org.whispersystems.textsecuregcm.push.MessageSender;
|
||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||
|
@ -141,13 +142,13 @@ public class ChangeNumberManagerTest {
|
|||
|
||||
final Device primaryDevice = mock(Device.class);
|
||||
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
when(primaryDevice.getRegistrationId()).thenReturn(7);
|
||||
when(primaryDevice.getRegistrationId(IdentityType.ACI)).thenReturn(7);
|
||||
|
||||
final Device linkedDevice = mock(Device.class);
|
||||
final byte linkedDeviceId = Device.PRIMARY_ID + 1;
|
||||
final int linkedDeviceRegistrationId = 17;
|
||||
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
|
||||
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
|
||||
when(linkedDevice.getRegistrationId(IdentityType.ACI)).thenReturn(linkedDeviceRegistrationId);
|
||||
|
||||
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
|
||||
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
|
||||
|
@ -401,7 +402,7 @@ public class ChangeNumberManagerTest {
|
|||
for (byte i = 1; i <= 3; i++) {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn(i);
|
||||
when(device.getRegistrationId()).thenReturn((int) i);
|
||||
when(device.getRegistrationId(IdentityType.ACI)).thenReturn((int) i);
|
||||
|
||||
devices.add(device);
|
||||
when(account.getDevice(i)).thenReturn(Optional.of(device));
|
||||
|
|
Loading…
Reference in New Issue