Assume all accounts have primary devices

This commit is contained in:
Jon Chambers 2023-12-05 17:32:49 -05:00 committed by Jon Chambers
parent 69990c23a5
commit 00e72a30c9
11 changed files with 47 additions and 44 deletions

View File

@ -45,7 +45,7 @@ public class PushChallengeManager {
}
public void sendChallenge(final Account account) throws NotPushRegisteredException {
final Device primaryDevice = account.getPrimaryDevice().orElseThrow(NotPushRegisteredException::new);
final Device primaryDevice = account.getPrimaryDevice();
final byte[] token = new byte[CHALLENGE_TOKEN_LENGTH];
random.nextBytes(token);
@ -86,16 +86,15 @@ public class PushChallengeManager {
} catch (final IllegalArgumentException ignored) {
}
final String platform = account.getPrimaryDevice().map(primaryDevice -> {
if (StringUtils.isNotBlank(primaryDevice.getGcmId())) {
return ClientPlatform.ANDROID.name().toLowerCase();
} else if (StringUtils.isNotBlank(primaryDevice.getApnId())) {
return ClientPlatform.IOS.name().toLowerCase();
} else {
return "unknown";
}
}).orElse("unknown");
final String platform;
if (StringUtils.isNotBlank(account.getPrimaryDevice().getGcmId())) {
platform = ClientPlatform.ANDROID.name().toLowerCase();
} else if (StringUtils.isNotBlank(account.getPrimaryDevice().getApnId())) {
platform = ClientPlatform.IOS.name().toLowerCase();
} else {
platform = "unknown";
}
Metrics.counter(CHALLENGE_ANSWERED_COUNTER_NAME,
PLATFORM_TAG_NAME, platform,

View File

@ -62,7 +62,7 @@ public class PushNotificationManager {
public void sendRateLimitChallengeNotification(final Account destination, final String challengeToken)
throws NotPushRegisteredException {
final Device device = destination.getDevice(Device.PRIMARY_ID).orElseThrow(NotPushRegisteredException::new);
final Device device = destination.getPrimaryDevice();
final Pair<String, PushNotification.TokenType> tokenAndType = getToken(device);
sendNotification(new PushNotification(tokenAndType.first(), tokenAndType.second(),

View File

@ -235,10 +235,11 @@ public class Account {
return devices;
}
public Optional<Device> getPrimaryDevice() {
public Device getPrimaryDevice() {
requireNotStale();
return getDevice(Device.PRIMARY_ID);
return getDevice(Device.PRIMARY_ID)
.orElseThrow(() -> new IllegalStateException("All accounts must have a primary device"));
}
public Optional<Device> getDevice(final byte deviceId) {
@ -256,7 +257,9 @@ public class Account {
public boolean isTransferSupported() {
requireNotStale();
return getPrimaryDevice().map(Device::getCapabilities).map(Device.DeviceCapabilities::transfer).orElse(false);
return Optional.ofNullable(getPrimaryDevice().getCapabilities())
.map(Device.DeviceCapabilities::transfer)
.orElse(false);
}
public boolean isPniSupported() {
@ -278,7 +281,7 @@ public class Account {
public boolean isEnabled() {
requireNotStale();
return getPrimaryDevice().map(Device::isEnabled).orElse(false);
return getPrimaryDevice().isEnabled();
}
public byte getNextDeviceId() {

View File

@ -174,7 +174,7 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();
assert account.getPrimaryDevice().orElseThrow().isEnabled();
assert account.getPrimaryDevice().isEnabled();
initialEnabled.forEach((deviceId, enabled) ->
DevicesHelper.setEnabled(account.getDevice(deviceId).orElseThrow(), enabled));
@ -217,7 +217,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@Test
void testDeviceAdded() {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
assert account.getPrimaryDevice().isEnabled();
final int initialDeviceCount = account.getDevices().size();
@ -241,7 +241,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
assert account.getPrimaryDevice().isEnabled();
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
@ -273,7 +273,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@Test
void testPrimaryDeviceDisabledAndDeviceRemoved() {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
assert account.getPrimaryDevice().isEnabled();
final Set<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
@ -421,7 +421,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@ChangesDeviceEnabledState
public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabled") final boolean enabled) {
final Device device = principal.getAccount().getPrimaryDevice().orElseThrow();
final Device device = principal.getAccount().getPrimaryDevice();
DevicesHelper.setEnabled(device, enabled);
@ -479,7 +479,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@ChangesDeviceEnabledState
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") byte deviceId) {
DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice().orElseThrow(), false);
DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice(), false);
auth.getAccount().removeDevice(deviceId);

View File

@ -165,7 +165,7 @@ class RegistrationControllerTest {
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
@ -288,7 +288,7 @@ class RegistrationControllerTest {
.thenReturn(CompletableFuture.completedFuture(true));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
@ -346,7 +346,7 @@ class RegistrationControllerTest {
expectedStatus = error.getExpectedStatus();
} else {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount);
@ -400,7 +400,7 @@ class RegistrationControllerTest {
when(accountsManager.getByE164(any())).thenReturn(maybeAccount);
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
@ -424,7 +424,7 @@ class RegistrationControllerTest {
SESSION_EXPIRATION_SECONDS))));
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
@ -682,7 +682,7 @@ class RegistrationControllerTest {
final Account account = MockUtils.buildMock(Account.class, a -> {
when(a.getUuid()).thenReturn(accountIdentifier);
when(a.getPhoneNumberIdentifier()).thenReturn(phoneNumberIdentifier);
when(a.getPrimaryDevice()).thenReturn(Optional.of(device));
when(a.getPrimaryDevice()).thenReturn(device);
});
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))

View File

@ -93,7 +93,7 @@ class PushNotificationManagerTest {
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(device.getApnId()).thenReturn(deviceToken);
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
when(account.getPrimaryDevice()).thenReturn(device);
when(apnSender.sendNotification(any()))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));

View File

@ -391,7 +391,7 @@ public class AccountCreationIntegrationTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
final Device primaryDevice = account.getPrimaryDevice().orElseThrow();
final Device primaryDevice = account.getPrimaryDevice();
assertEquals(number, account.getNumber());
assertEquals(signalAgent, primaryDevice.getUserAgent());

View File

@ -188,7 +188,7 @@ class AccountsManagerChangeNumberIntegrationTest {
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, rotatedPniRegistrationId, "test", null, true, new Device.DeviceCapabilities(false, false, false, false));
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes);
account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair));
account.getPrimaryDevice().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair));
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
@ -213,9 +213,9 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(pniIdentityKey, updatedAccount.getIdentityKey(IdentityType.PNI));
assertEquals(OptionalInt.of(rotatedPniRegistrationId),
updatedAccount.getPrimaryDevice().orElseThrow().getPhoneNumberIdentityRegistrationId());
updatedAccount.getPrimaryDevice().getPhoneNumberIdentityRegistrationId());
assertEquals(rotatedSignedPreKey, updatedAccount.getPrimaryDevice().orElseThrow().getSignedPreKey(IdentityType.PNI));
assertEquals(rotatedSignedPreKey, updatedAccount.getPrimaryDevice().getSignedPreKey(IdentityType.PNI));
}
@Test

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
@ -978,7 +979,7 @@ class AccountsManagerTest {
assertThrows(IllegalArgumentException.class, () -> accountsManager.removeDevice(account, Device.PRIMARY_ID));
assertTrue(account.getPrimaryDevice().isPresent());
assertDoesNotThrow(() -> account.getPrimaryDevice());
verify(messagesManager, never()).clear(any(), anyByte());
verify(keysManager, never()).delete(any(), anyByte());
verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte());
@ -1616,8 +1617,8 @@ class AccountsManagerTest {
assertEquals(originalAccount.getDevices().size(), parsedAccount.getDevices().size());
final Device originalDevice = originalAccount.getPrimaryDevice().orElseThrow();
final Device parsedDevice = parsedAccount.getPrimaryDevice().orElseThrow();
final Device originalDevice = originalAccount.getPrimaryDevice();
final Device parsedDevice = parsedAccount.getPrimaryDevice();
assertEquals(originalDevice.getId(), parsedDevice.getId());
assertEquals(originalDevice.getSignedPreKey(IdentityType.ACI), parsedDevice.getSignedPreKey(IdentityType.ACI));

View File

@ -570,7 +570,7 @@ class AccountsTest {
final String deviceName = "device-name";
assertNotEquals(deviceName,
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName());
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().getName());
assertFalse(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())
@ -581,7 +581,7 @@ class AccountsTest {
.build())
.hasItem());
account.getPrimaryDevice().orElseThrow().setName(deviceName);
account.getPrimaryDevice().setName(deviceName);
accounts.updateTransactionallyAsync(account, List.of(TransactWriteItem.builder()
.put(Put.builder()
@ -594,7 +594,7 @@ class AccountsTest {
.build())).toCompletableFuture().join();
assertEquals(deviceName,
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName());
accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().getName());
assertTrue(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
.tableName(Tables.CLIENT_RELEASES.tableName())

View File

@ -139,15 +139,15 @@ public class AuthHelper {
when(VALID_DEVICE_3_LINKED.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT.getPrimaryDevice()).thenReturn(VALID_DEVICE);
when(VALID_ACCOUNT_TWO.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(VALID_ACCOUNT_TWO.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(VALID_ACCOUNT_TWO.getPrimaryDevice()).thenReturn(VALID_DEVICE_TWO);
when(DISABLED_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(DISABLED_DEVICE));
when(DISABLED_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(DISABLED_DEVICE));
when(DISABLED_ACCOUNT.getPrimaryDevice()).thenReturn(DISABLED_DEVICE);
when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getPrimaryDevice()).thenReturn(UNDISCOVERABLE_DEVICE);
when(VALID_ACCOUNT_3.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(VALID_DEVICE_3_PRIMARY);
when(VALID_ACCOUNT_3.getDevice((byte) 2)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
when(VALID_ACCOUNT.getDevices()).thenReturn(List.of(VALID_DEVICE));
@ -289,7 +289,7 @@ public class AuthHelper {
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(device.isEnabled()).thenReturn(true);
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device));
when(account.getPrimaryDevice()).thenReturn(Optional.of(device));
when(account.getPrimaryDevice()).thenReturn(device);
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.isEnabled()).thenReturn(true);