Use all devices when checking limit
This commit is contained in:
parent
38b581a231
commit
ba139dddd8
|
@ -171,8 +171,8 @@ public class DeviceController {
|
||||||
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
|
if (account.getDevices().size() >= maxDeviceLimit) {
|
||||||
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
|
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
|
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
|
||||||
|
@ -386,8 +386,8 @@ public class DeviceController {
|
||||||
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
|
if (account.getDevices().size() >= maxDeviceLimit) {
|
||||||
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
|
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
final DeviceCapabilities capabilities = accountAttributes.getCapabilities();
|
final DeviceCapabilities capabilities = accountAttributes.getCapabilities();
|
||||||
|
|
|
@ -262,7 +262,7 @@ public class MessageController {
|
||||||
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
|
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1;
|
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice();
|
||||||
|
|
||||||
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
|
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
|
||||||
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
|
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
|
||||||
|
|
|
@ -290,16 +290,12 @@ public class Account {
|
||||||
return candidateId;
|
return candidateId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getEnabledDeviceCount() {
|
public boolean hasEnabledLinkedDevice() {
|
||||||
requireNotStale();
|
requireNotStale();
|
||||||
|
|
||||||
int count = 0;
|
return devices.stream()
|
||||||
|
.filter(d -> Device.PRIMARY_ID != d.getId())
|
||||||
for (final Device device : devices) {
|
.anyMatch(Device::isEnabled);
|
||||||
if (device.isEnabled()) count++;
|
|
||||||
}
|
|
||||||
|
|
||||||
return count;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setIdentityKey(final IdentityKey identityKey) {
|
public void setIdentityKey(final IdentityKey identityKey) {
|
||||||
|
|
|
@ -35,6 +35,7 @@ import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
import javax.ws.rs.client.Entity;
|
import javax.ws.rs.client.Entity;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
|
@ -44,6 +45,7 @@ import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
@ -107,6 +109,9 @@ class DeviceControllerTest {
|
||||||
deviceConfiguration,
|
deviceConfiguration,
|
||||||
testClock);
|
testClock);
|
||||||
|
|
||||||
|
@RegisterExtension
|
||||||
|
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
|
||||||
|
|
||||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||||
.addProvider(AuthHelper.getAuthFilter())
|
.addProvider(AuthHelper.getAuthFilter())
|
||||||
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
|
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||||
|
@ -630,10 +635,17 @@ class DeviceControllerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void maxDevicesTest() {
|
void maxDevicesTest() {
|
||||||
|
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();
|
||||||
|
|
||||||
|
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
|
||||||
|
.mapToObj(i -> mock(Device.class))
|
||||||
|
.toList();
|
||||||
|
when(testAccount.account.getDevices()).thenReturn(devices);
|
||||||
|
|
||||||
Response response = resources.getJerseyTest()
|
Response response = resources.getJerseyTest()
|
||||||
.target("/v1/devices/provisioning/code")
|
.target("/v1/devices/provisioning/code")
|
||||||
.request()
|
.request()
|
||||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
|
.header("Authorization", testAccount.getAuthHeader())
|
||||||
.get();
|
.get();
|
||||||
|
|
||||||
assertEquals(411, response.getStatus());
|
assertEquals(411, response.getStatus());
|
||||||
|
|
|
@ -27,8 +27,12 @@ import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Stream;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
|
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
|
||||||
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
|
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
|
||||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||||
|
@ -380,4 +384,49 @@ class AccountTest {
|
||||||
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
|
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
|
||||||
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
|
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource
|
||||||
|
public void testHasEnabledLinkedDevice(final Account account, final boolean expect) {
|
||||||
|
assertEquals(expect, account.hasEnabledLinkedDevice());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Stream<Arguments> testHasEnabledLinkedDevice() {
|
||||||
|
final Device enabledPrimary = mock(Device.class);
|
||||||
|
when(enabledPrimary.isEnabled()).thenReturn(true);
|
||||||
|
when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||||
|
|
||||||
|
final Device disabledPrimary = mock(Device.class);
|
||||||
|
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||||
|
|
||||||
|
final long linked1DeviceId = Device.PRIMARY_ID + 1;
|
||||||
|
final Device enabledLinked1 = mock(Device.class);
|
||||||
|
when(enabledLinked1.isEnabled()).thenReturn(true);
|
||||||
|
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);
|
||||||
|
|
||||||
|
final Device disabledLinked1 = mock(Device.class);
|
||||||
|
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);
|
||||||
|
|
||||||
|
final long linked2DeviceId = Device.PRIMARY_ID + 2;
|
||||||
|
final Device enabledLinked2 = mock(Device.class);
|
||||||
|
when(enabledLinked2.isEnabled()).thenReturn(true);
|
||||||
|
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);
|
||||||
|
|
||||||
|
final Device disabledLinked2 = mock(Device.class);
|
||||||
|
when(disabledLinked2.getId()).thenReturn(linked2DeviceId);
|
||||||
|
|
||||||
|
return Stream.of(
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary)), false),
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary, disabledLinked1)),
|
||||||
|
false),
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||||
|
List.of(enabledPrimary, disabledLinked1, disabledLinked2)), false),
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||||
|
List.of(enabledPrimary, enabledLinked1, disabledLinked2)), true),
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||||
|
List.of(enabledPrimary, disabledLinked1, enabledLinked2)), true),
|
||||||
|
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||||
|
List.of(disabledLinked2, enabledLinked1, enabledLinked2)), true)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ public class AccountsHelper {
|
||||||
case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
|
case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
|
||||||
case "isPniSupported" -> when(updatedAccount.isPniSupported()).thenAnswer(stubbing);
|
case "isPniSupported" -> when(updatedAccount.isPniSupported()).thenAnswer(stubbing);
|
||||||
case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing);
|
case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing);
|
||||||
case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
|
case "hasEnabledLinkedDevice" -> when(updatedAccount.hasEnabledLinkedDevice()).thenAnswer(stubbing);
|
||||||
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
|
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
|
||||||
case "getIdentityKey" ->
|
case "getIdentityKey" ->
|
||||||
when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
|
when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
|
||||||
|
|
|
@ -11,15 +11,22 @@ import static org.mockito.Mockito.reset;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.common.collect.ImmutableMap;
|
import com.google.common.collect.ImmutableMap;
|
||||||
|
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||||
|
import com.google.i18n.phonenumbers.Phonenumber;
|
||||||
import io.dropwizard.auth.AuthFilter;
|
import io.dropwizard.auth.AuthFilter;
|
||||||
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
|
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
|
||||||
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
|
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
|
||||||
import io.dropwizard.auth.basic.BasicCredentials;
|
import io.dropwizard.auth.basic.BasicCredentials;
|
||||||
import java.security.Principal;
|
import java.security.Principal;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import org.junit.jupiter.api.extension.AfterEachCallback;
|
||||||
|
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||||
import org.signal.libsignal.protocol.IdentityKey;
|
import org.signal.libsignal.protocol.IdentityKey;
|
||||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||||
|
@ -90,6 +97,8 @@ public class AuthHelper {
|
||||||
private static SaltedTokenHash DISABLED_CREDENTIALS = mock(SaltedTokenHash.class);
|
private static SaltedTokenHash DISABLED_CREDENTIALS = mock(SaltedTokenHash.class);
|
||||||
private static SaltedTokenHash UNDISCOVERABLE_CREDENTIALS = mock(SaltedTokenHash.class);
|
private static SaltedTokenHash UNDISCOVERABLE_CREDENTIALS = mock(SaltedTokenHash.class);
|
||||||
|
|
||||||
|
private static final Collection<TestAccount> EXTENSION_TEST_ACCOUNTS = new HashSet<>();
|
||||||
|
|
||||||
public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter() {
|
public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter() {
|
||||||
when(VALID_CREDENTIALS.verify("foo")).thenReturn(true);
|
when(VALID_CREDENTIALS.verify("foo")).thenReturn(true);
|
||||||
when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true);
|
when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true);
|
||||||
|
@ -138,7 +147,7 @@ public class AuthHelper {
|
||||||
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
|
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
|
||||||
when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
|
when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
|
||||||
|
|
||||||
when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6);
|
when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true);
|
||||||
|
|
||||||
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
|
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
|
||||||
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
|
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
|
||||||
|
@ -261,6 +270,11 @@ public class AuthHelper {
|
||||||
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
|
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
|
||||||
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
|
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void teardown(final AccountsManager accountsManager) {
|
||||||
|
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.empty());
|
||||||
|
when(accountsManager.getByE164(number)).thenReturn(Optional.empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static TestAccount[] generateTestAccounts() {
|
private static TestAccount[] generateTestAccounts() {
|
||||||
|
@ -272,4 +286,35 @@ public class AuthHelper {
|
||||||
}
|
}
|
||||||
return testAccounts;
|
return testAccounts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* JUnit 5 extension for creating {@link TestAccount}s scoped to a single test
|
||||||
|
*/
|
||||||
|
public static class AuthFilterExtension implements AfterEachCallback {
|
||||||
|
|
||||||
|
public TestAccount createTestAccount() {
|
||||||
|
final UUID uuid = UUID.randomUUID();
|
||||||
|
final String region = new ArrayList<>((PhoneNumberUtil.getInstance().getSupportedRegions())).get(
|
||||||
|
EXTENSION_TEST_ACCOUNTS.size());
|
||||||
|
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().getExampleNumber(region);
|
||||||
|
|
||||||
|
final TestAccount testAccount = new TestAccount(
|
||||||
|
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164), uuid,
|
||||||
|
"extension-password-" + region);
|
||||||
|
testAccount.setup(ACCOUNTS_MANAGER);
|
||||||
|
|
||||||
|
EXTENSION_TEST_ACCOUNTS.add(testAccount);
|
||||||
|
|
||||||
|
return testAccount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void afterEach(final ExtensionContext context) {
|
||||||
|
EXTENSION_TEST_ACCOUNTS.forEach(testAccount -> {
|
||||||
|
testAccount.teardown(ACCOUNTS_MANAGER);
|
||||||
|
});
|
||||||
|
|
||||||
|
EXTENSION_TEST_ACCOUNTS.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue