Fix registration ID map construction when changing numbers

This commit is contained in:
Jon Chambers 2025-04-15 14:57:28 -04:00 committed by GitHub
parent 2f2ae7cec5
commit 3c40e72d27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 28 deletions

View File

@ -693,7 +693,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb); pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build(); HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients() FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -29,12 +30,16 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class); private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender; private final MessageSender messageSender;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager( public ChangeNumberManager(
final MessageSender messageSender, final MessageSender messageSender,
final AccountsManager accountsManager) { final AccountsManager accountsManager,
final Clock clock) {
this.messageSender = messageSender; this.messageSender = messageSender;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.clock = clock;
} }
public Account changeNumber(final Account account, final String number, public Account changeNumber(final Account account, final String number,
@ -97,7 +102,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException { final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try { try {
final long serverTimestamp = System.currentTimeMillis(); final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid()); final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream() final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -114,8 +119,8 @@ public class ChangeNumberManager {
.setEphemeral(false) .setEphemeral(false)
.build())); .build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream() final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId)); .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account, messageSender.sendMessages(account,
serviceIdentifier, serviceIdentifier,

View File

@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -45,11 +50,13 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount; private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach @BeforeEach
void setUp() throws Exception { void setUp() throws Exception {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class); messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>(); updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@ -132,45 +139,59 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni); when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class); final Device primaryDevice = mock(Device.class);
final byte deviceId2 = 2; when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(d2.getId()).thenReturn(deviceId2); when(primaryDevice.getRegistrationId()).thenReturn(7);
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); final Device linkedDevice = mock(Device.class);
when(account.getDevices()).thenReturn(List.of(d2)); final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID, final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair), KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19);
final IncomingMessage msg = mock(IncomingMessage.class); final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.type()).thenReturn(1);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1}); when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds); verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor = final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder()
ArgumentCaptor.forClass(Map.class); .setType(MessageProtos.Envelope.Type.forNumber(msg.type()))
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any()); verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)),
eq(new AciServiceIdentifier(aci)),
assertEquals(1, envelopeCaptor.getValue().size()); eq(Map.of(linkedDeviceId, expectedEnvelope)),
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet()); eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)),
eq(Optional.of(Device.PRIMARY_ID)),
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2); any());
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
} }
@Test @Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception { void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234"; final String originalE164 = "+18005551234";