diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 41b44c209..f7ae518bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -171,8 +171,8 @@ public class DeviceController { maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber()); } - if (account.getEnabledDeviceCount() >= maxDeviceLimit) { - throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES); + if (account.getDevices().size() >= maxDeviceLimit) { + throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit); } if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) { @@ -386,8 +386,8 @@ public class DeviceController { maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber()); } - if (account.getEnabledDeviceCount() >= maxDeviceLimit) { - throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES); + if (account.getDevices().size() >= maxDeviceLimit) { + throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit); } final DeviceCapabilities capabilities = accountAttributes.getCapabilities(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index fdde522ef..eff65ab67 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -262,7 +262,7 @@ public class MessageController { OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); } - boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1; + boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice(); // We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify // we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index 1a518cc7d..ec41da64e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -290,16 +290,12 @@ public class Account { return candidateId; } - public int getEnabledDeviceCount() { + public boolean hasEnabledLinkedDevice() { requireNotStale(); - int count = 0; - - for (final Device device : devices) { - if (device.isEnabled()) count++; - } - - return count; + return devices.stream() + .filter(d -> Device.PRIMARY_ID != d.getId()) + .anyMatch(Device::isEnabled); } public void setIdentityKey(final IdentityKey identityKey) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index a4905fa30..63c379f35 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -35,6 +35,7 @@ import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.stream.IntStream; import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; @@ -44,6 +45,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -107,6 +109,9 @@ class DeviceControllerTest { deviceConfiguration, testClock); + @RegisterExtension + public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension(); + private static final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( @@ -630,10 +635,17 @@ class DeviceControllerTest { @Test void maxDevicesTest() { + final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount(); + + final List devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1) + .mapToObj(i -> mock(Device.class)) + .toList(); + when(testAccount.account.getDevices()).thenReturn(devices); + Response response = resources.getJerseyTest() .target("/v1/devices/provisioning/code") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO)) + .header("Authorization", testAccount.getAuthHeader()) .get(); assertEquals(411, response.getStatus()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java index e7bef84ab..d6fcab4a1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java @@ -27,8 +27,12 @@ import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.util.TestClock; @@ -380,4 +384,49 @@ class AccountTest { final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get(); assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value()); } + + @ParameterizedTest + @MethodSource + public void testHasEnabledLinkedDevice(final Account account, final boolean expect) { + assertEquals(expect, account.hasEnabledLinkedDevice()); + } + + static Stream testHasEnabledLinkedDevice() { + final Device enabledPrimary = mock(Device.class); + when(enabledPrimary.isEnabled()).thenReturn(true); + when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID); + + final Device disabledPrimary = mock(Device.class); + when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID); + + final long linked1DeviceId = Device.PRIMARY_ID + 1; + final Device enabledLinked1 = mock(Device.class); + when(enabledLinked1.isEnabled()).thenReturn(true); + when(enabledLinked1.getId()).thenReturn(linked1DeviceId); + + final Device disabledLinked1 = mock(Device.class); + when(disabledLinked1.getId()).thenReturn(linked1DeviceId); + + final long linked2DeviceId = Device.PRIMARY_ID + 2; + final Device enabledLinked2 = mock(Device.class); + when(enabledLinked2.isEnabled()).thenReturn(true); + when(enabledLinked2.getId()).thenReturn(linked2DeviceId); + + final Device disabledLinked2 = mock(Device.class); + when(disabledLinked2.getId()).thenReturn(linked2DeviceId); + + return Stream.of( + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary)), false), + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary, disabledLinked1)), + false), + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", + List.of(enabledPrimary, disabledLinked1, disabledLinked2)), false), + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", + List.of(enabledPrimary, enabledLinked1, disabledLinked2)), true), + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", + List.of(enabledPrimary, disabledLinked1, enabledLinked2)), true), + Arguments.of(AccountsHelper.generateTestAccount("+14155550123", + List.of(disabledLinked2, enabledLinked1, enabledLinked2)), true) + ); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 480f4ea4e..f874f7704 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -123,7 +123,7 @@ public class AccountsHelper { case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing); case "isPniSupported" -> when(updatedAccount.isPniSupported()).thenAnswer(stubbing); case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing); - case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing); + case "hasEnabledLinkedDevice" -> when(updatedAccount.hasEnabledLinkedDevice()).thenAnswer(stubbing); case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing); case "getIdentityKey" -> when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 951af1590..dee0f36bc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -11,15 +11,22 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import com.google.i18n.phonenumbers.Phonenumber; import io.dropwizard.auth.AuthFilter; import io.dropwizard.auth.PolymorphicAuthDynamicFeature; import io.dropwizard.auth.basic.BasicCredentialAuthFilter; import io.dropwizard.auth.basic.BasicCredentials; import java.security.Principal; +import java.util.ArrayList; import java.util.Base64; +import java.util.Collection; +import java.util.HashSet; import java.util.Optional; import java.util.Random; import java.util.UUID; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; @@ -90,6 +97,8 @@ public class AuthHelper { private static SaltedTokenHash DISABLED_CREDENTIALS = mock(SaltedTokenHash.class); private static SaltedTokenHash UNDISCOVERABLE_CREDENTIALS = mock(SaltedTokenHash.class); + private static final Collection EXTENSION_TEST_ACCOUNTS = new HashSet<>(); + public static PolymorphicAuthDynamicFeature getAuthFilter() { when(VALID_CREDENTIALS.verify("foo")).thenReturn(true); when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true); @@ -138,7 +147,7 @@ public class AuthHelper { when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED)); - when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6); + when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true); when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER); when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID); @@ -261,6 +270,11 @@ public class AuthHelper { when(accountsManager.getByE164(number)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); } + + private void teardown(final AccountsManager accountsManager) { + when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()); + when(accountsManager.getByE164(number)).thenReturn(Optional.empty()); + } } private static TestAccount[] generateTestAccounts() { @@ -272,4 +286,35 @@ public class AuthHelper { } return testAccounts; } + + /** + * JUnit 5 extension for creating {@link TestAccount}s scoped to a single test + */ + public static class AuthFilterExtension implements AfterEachCallback { + + public TestAccount createTestAccount() { + final UUID uuid = UUID.randomUUID(); + final String region = new ArrayList<>((PhoneNumberUtil.getInstance().getSupportedRegions())).get( + EXTENSION_TEST_ACCOUNTS.size()); + final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().getExampleNumber(region); + + final TestAccount testAccount = new TestAccount( + PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164), uuid, + "extension-password-" + region); + testAccount.setup(ACCOUNTS_MANAGER); + + EXTENSION_TEST_ACCOUNTS.add(testAccount); + + return testAccount; + } + + @Override + public void afterEach(final ExtensionContext context) { + EXTENSION_TEST_ACCOUNTS.forEach(testAccount -> { + testAccount.teardown(ACCOUNTS_MANAGER); + }); + + EXTENSION_TEST_ACCOUNTS.clear(); + } + } }