Fix registration ID map construction when changing numbers

This commit is contained in:
Jon Chambers 2025-04-15 14:57:28 -04:00 committed by GitHub
parent 2f2ae7cec5
commit 3c40e72d27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 54 additions and 28 deletions

View File

@ -693,7 +693,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager,
pushChallengeDynamoDb);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, Clock.systemUTC());
HttpClient currencyClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).connectTimeout(Duration.ofSeconds(10)).build();
FixerClient fixerClient = config.getPaymentsServiceConfiguration().externalClients()

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import java.time.Clock;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -29,12 +30,16 @@ public class ChangeNumberManager {
private static final Logger logger = LoggerFactory.getLogger(ChangeNumberManager.class);
private final MessageSender messageSender;
private final AccountsManager accountsManager;
private final Clock clock;
public ChangeNumberManager(
final MessageSender messageSender,
final AccountsManager accountsManager) {
final AccountsManager accountsManager,
final Clock clock) {
this.messageSender = messageSender;
this.accountsManager = accountsManager;
this.clock = clock;
}
public Account changeNumber(final Account account, final String number,
@ -97,7 +102,7 @@ public class ChangeNumberManager {
final String senderUserAgent) throws MessageTooLargeException, MismatchedDevicesException {
try {
final long serverTimestamp = System.currentTimeMillis();
final long serverTimestamp = clock.millis();
final ServiceIdentifier serviceIdentifier = new AciServiceIdentifier(account.getUuid());
final Map<Byte, Envelope> messagesByDeviceId = deviceMessages.stream()
@ -114,8 +119,8 @@ public class ChangeNumberManager {
.setEphemeral(false)
.build()));
final Map<Byte, Integer> registrationIdsByDeviceId = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, Device::getRegistrationId));
final Map<Byte, Integer> registrationIdsByDeviceId = deviceMessages.stream()
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
messageSender.sendMessages(account,
serviceIdentifier,

View File

@ -9,13 +9,16 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@ -35,8 +38,10 @@ import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
public class ChangeNumberManagerTest {
private AccountsManager accountsManager;
@ -45,11 +50,13 @@ public class ChangeNumberManagerTest {
private Map<Account, UUID> updatedPhoneNumberIdentifiersByAccount;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@BeforeEach
void setUp() throws Exception {
accountsManager = mock(AccountsManager.class);
messageSender = mock(MessageSender.class);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);
changeNumberManager = new ChangeNumberManager(messageSender, accountsManager, CLOCK);
updatedPhoneNumberIdentifiersByAccount = new HashMap<>();
@ -132,45 +139,59 @@ public class ChangeNumberManagerTest {
when(account.getUuid()).thenReturn(aci);
when(account.getPhoneNumberIdentifier()).thenReturn(pni);
final Device d2 = mock(Device.class);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(primaryDevice.getRegistrationId()).thenReturn(7);
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final Device linkedDevice = mock(Device.class);
final byte linkedDeviceId = Device.PRIMARY_ID + 1;
final int linkedDeviceRegistrationId = 17;
when(linkedDevice.getId()).thenReturn(linkedDeviceId);
when(linkedDevice.getRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevice(linkedDeviceId)).thenReturn(Optional.of(linkedDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDevice));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
linkedDeviceId, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, linkedDeviceId, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.type()).thenReturn(1);
when(msg.destinationDeviceId()).thenReturn(linkedDeviceId);
when(msg.destinationRegistrationId()).thenReturn(linkedDeviceRegistrationId);
when(msg.content()).thenReturn(new byte[]{1});
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds, null);
verify(accountsManager).changeNumber(account, changedE164, pniIdentityKey, prekeys, null, registrationIds);
@SuppressWarnings("unchecked") final ArgumentCaptor<Map<Byte, MessageProtos.Envelope>> envelopeCaptor =
ArgumentCaptor.forClass(Map.class);
final MessageProtos.Envelope expectedEnvelope = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(msg.type()))
.setClientTimestamp(CLOCK.millis())
.setServerTimestamp(CLOCK.millis())
.setDestinationServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setContent(ByteString.copyFrom(msg.content()))
.setSourceServiceId(new AciServiceIdentifier(aci).toServiceIdentifierString())
.setSourceDevice(Device.PRIMARY_ID)
.setUpdatedPni(updatedPhoneNumberIdentifiersByAccount.get(account).toString())
.setUrgent(true)
.setEphemeral(false)
.build();
verify(messageSender).sendMessages(any(), any(), envelopeCaptor.capture(), any(), any(), any());
assertEquals(1, envelopeCaptor.getValue().size());
assertEquals(Set.of(deviceId2), envelopeCaptor.getValue().keySet());
final MessageProtos.Envelope envelope = envelopeCaptor.getValue().get(deviceId2);
assertEquals(aci, UUID.fromString(envelope.getDestinationServiceId()));
assertEquals(aci, UUID.fromString(envelope.getSourceServiceId()));
assertEquals(Device.PRIMARY_ID, envelope.getSourceDevice());
assertEquals(updatedPhoneNumberIdentifiersByAccount.get(account), UUID.fromString(envelope.getUpdatedPni()));
verify(messageSender).sendMessages(argThat(a -> a.getUuid().equals(aci)),
eq(new AciServiceIdentifier(aci)),
eq(Map.of(linkedDeviceId, expectedEnvelope)),
eq(Map.of(linkedDeviceId, linkedDeviceRegistrationId)),
eq(Optional.of(Device.PRIMARY_ID)),
any());
}
@Test
void changeNumberSetPrimaryDevicePrekeyPqAndSendMessages() throws Exception {
final String originalE164 = "+18005551234";