diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2a5372b77..3fb4fa2c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -693,7 +693,7 @@ public class WhisperServerService extends Application messagesByDeviceId = deviceMessages.stream() @@ -114,8 +119,8 @@ public class ChangeNumberManager { .setEphemeral(false) .build())); - final Map registrationIdsByDeviceId = account.getDevices().stream() - .collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); + final Map registrationIdsByDeviceId = deviceMessages.stream() + .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); messageSender.sendMessages(account, serviceIdentifier, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 9830a9e35..cfb0bcbce 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.protobuf.ByteString; import java.nio.charset.StandardCharsets; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.TestClock; public class ChangeNumberManagerTest { private AccountsManager accountsManager; @@ -45,11 +50,13 @@ public class ChangeNumberManagerTest { private Map updatedPhoneNumberIdentifiersByAccount; + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + @BeforeEach void setUp() throws Exception { accountsManager = mock(AccountsManager.class); messageSender = mock(MessageSender.class); - changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); + changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK); updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); @@ -132,45 +139,59 @@ public class ChangeNumberManagerTest { when(account.getUuid()).thenReturn(aci); when(account.getPhoneNumberIdentifier()).thenReturn(pni); - final Device d2 = mock(Device.class); - final byte deviceId2 = 2; - when(d2.getId()).thenReturn(deviceId2); + final Device primaryDevice = mock(Device.class); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(primaryDevice.getRegistrationId()).thenReturn(7); - when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); - when(account.getDevices()).thenReturn(List.of(d2)); + final Device linkedDevice = mock(Device.class); + final byte linkedDeviceId = Device.PRIMARY_ID + 1; + final int linkedDeviceRegistrationId = 17; + when(linkedDevice.getId()).thenReturn(linkedDeviceId); + when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId); + + when(account.getDevice(anyByte())).thenReturn(Optional.empty()); + when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice)); + when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice)); + when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final Map prekeys = Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), - deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); + linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(deviceId2); + when(msg.type()).thenReturn(1); + when(msg.destinationDeviceId()).thenReturn(linkedDeviceId); + when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId); when(msg.content()).thenReturn(new byte[]{1}); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); - @SuppressWarnings("unchecked") final ArgumentCaptor> envelopeCaptor = - ArgumentCaptor.forClass(Map.class); + final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder() + .setType(MessageProtos.Envelope.Type.forNumber(msg.type())) + .setClientTimestamp(CLOCK.millis()) + .setServerTimestamp(CLOCK.millis()) + .setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString()) + .setContent(ByteString.copyFrom(msg.content())) + .setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString()) + .setSourceDevice(Device.PRIMARY_ID) + .setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString()) + .setUrgent(true) + .setEphemeral(false) + .build(); - verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); - - assertEquals(1, envelopeCaptor.getValue().size()); - assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); - - final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); - - assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId())); - assertEquals(aci, UUID.fromString(envelope.getSourceServiceId())); - assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice()); - assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni())); + verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)), + eq(new AciServiceIdentifier(aci)), + eq(Map.of(linkedDeviceId, expectedEnvelope)), + eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)), + eq(Optional.of(Device.PRIMARY_ID)), + any()); } - @Test void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { final String originalE164 = "+18005551234";