Use all devices when checking limit
This commit is contained in:
parent
38b581a231
commit
ba139dddd8
|
@ -171,8 +171,8 @@ public class DeviceController {
|
|||
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
||||
}
|
||||
|
||||
if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
|
||||
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
|
||||
if (account.getDevices().size() >= maxDeviceLimit) {
|
||||
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
|
||||
}
|
||||
|
||||
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
|
||||
|
@ -386,8 +386,8 @@ public class DeviceController {
|
|||
maxDeviceLimit = maxDeviceConfiguration.get(account.getNumber());
|
||||
}
|
||||
|
||||
if (account.getEnabledDeviceCount() >= maxDeviceLimit) {
|
||||
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
|
||||
if (account.getDevices().size() >= maxDeviceLimit) {
|
||||
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
|
||||
}
|
||||
|
||||
final DeviceCapabilities capabilities = accountAttributes.getCapabilities();
|
||||
|
|
|
@ -262,7 +262,7 @@ public class MessageController {
|
|||
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
|
||||
}
|
||||
|
||||
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1;
|
||||
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice();
|
||||
|
||||
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
|
||||
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
|
||||
|
|
|
@ -290,16 +290,12 @@ public class Account {
|
|||
return candidateId;
|
||||
}
|
||||
|
||||
public int getEnabledDeviceCount() {
|
||||
public boolean hasEnabledLinkedDevice() {
|
||||
requireNotStale();
|
||||
|
||||
int count = 0;
|
||||
|
||||
for (final Device device : devices) {
|
||||
if (device.isEnabled()) count++;
|
||||
}
|
||||
|
||||
return count;
|
||||
return devices.stream()
|
||||
.filter(d -> Device.PRIMARY_ID != d.getId())
|
||||
.anyMatch(Device::isEnabled);
|
||||
}
|
||||
|
||||
public void setIdentityKey(final IdentityKey identityKey) {
|
||||
|
|
|
@ -35,6 +35,7 @@ import java.util.Map;
|
|||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.IntStream;
|
||||
import java.util.stream.Stream;
|
||||
import javax.ws.rs.client.Entity;
|
||||
import javax.ws.rs.core.MediaType;
|
||||
|
@ -44,6 +45,7 @@ import org.junit.jupiter.api.AfterEach;
|
|||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
@ -107,6 +109,9 @@ class DeviceControllerTest {
|
|||
deviceConfiguration,
|
||||
testClock);
|
||||
|
||||
@RegisterExtension
|
||||
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
|
||||
|
||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||
.addProvider(AuthHelper.getAuthFilter())
|
||||
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||
|
@ -630,10 +635,17 @@ class DeviceControllerTest {
|
|||
|
||||
@Test
|
||||
void maxDevicesTest() {
|
||||
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();
|
||||
|
||||
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
|
||||
.mapToObj(i -> mock(Device.class))
|
||||
.toList();
|
||||
when(testAccount.account.getDevices()).thenReturn(devices);
|
||||
|
||||
Response response = resources.getJerseyTest()
|
||||
.target("/v1/devices/provisioning/code")
|
||||
.request()
|
||||
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
|
||||
.header("Authorization", testAccount.getAuthHeader())
|
||||
.get();
|
||||
|
||||
assertEquals(411, response.getStatus());
|
||||
|
|
|
@ -27,8 +27,12 @@ import java.util.List;
|
|||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Stream;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
|
||||
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
|
||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||
|
@ -380,4 +384,49 @@ class AccountTest {
|
|||
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
|
||||
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
public void testHasEnabledLinkedDevice(final Account account, final boolean expect) {
|
||||
assertEquals(expect, account.hasEnabledLinkedDevice());
|
||||
}
|
||||
|
||||
static Stream<Arguments> testHasEnabledLinkedDevice() {
|
||||
final Device enabledPrimary = mock(Device.class);
|
||||
when(enabledPrimary.isEnabled()).thenReturn(true);
|
||||
when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
final Device disabledPrimary = mock(Device.class);
|
||||
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
final long linked1DeviceId = Device.PRIMARY_ID + 1;
|
||||
final Device enabledLinked1 = mock(Device.class);
|
||||
when(enabledLinked1.isEnabled()).thenReturn(true);
|
||||
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);
|
||||
|
||||
final Device disabledLinked1 = mock(Device.class);
|
||||
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);
|
||||
|
||||
final long linked2DeviceId = Device.PRIMARY_ID + 2;
|
||||
final Device enabledLinked2 = mock(Device.class);
|
||||
when(enabledLinked2.isEnabled()).thenReturn(true);
|
||||
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);
|
||||
|
||||
final Device disabledLinked2 = mock(Device.class);
|
||||
when(disabledLinked2.getId()).thenReturn(linked2DeviceId);
|
||||
|
||||
return Stream.of(
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary)), false),
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary, disabledLinked1)),
|
||||
false),
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||
List.of(enabledPrimary, disabledLinked1, disabledLinked2)), false),
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||
List.of(enabledPrimary, enabledLinked1, disabledLinked2)), true),
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||
List.of(enabledPrimary, disabledLinked1, enabledLinked2)), true),
|
||||
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
|
||||
List.of(disabledLinked2, enabledLinked1, enabledLinked2)), true)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -123,7 +123,7 @@ public class AccountsHelper {
|
|||
case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
|
||||
case "isPniSupported" -> when(updatedAccount.isPniSupported()).thenAnswer(stubbing);
|
||||
case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing);
|
||||
case "getEnabledDeviceCount" -> when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
|
||||
case "hasEnabledLinkedDevice" -> when(updatedAccount.hasEnabledLinkedDevice()).thenAnswer(stubbing);
|
||||
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
|
||||
case "getIdentityKey" ->
|
||||
when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
|
||||
|
|
|
@ -11,15 +11,22 @@ import static org.mockito.Mockito.reset;
|
|||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||
import com.google.i18n.phonenumbers.Phonenumber;
|
||||
import io.dropwizard.auth.AuthFilter;
|
||||
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
|
||||
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
|
||||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import java.security.Principal;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.extension.AfterEachCallback;
|
||||
import org.junit.jupiter.api.extension.ExtensionContext;
|
||||
import org.signal.libsignal.protocol.IdentityKey;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
|
@ -90,6 +97,8 @@ public class AuthHelper {
|
|||
private static SaltedTokenHash DISABLED_CREDENTIALS = mock(SaltedTokenHash.class);
|
||||
private static SaltedTokenHash UNDISCOVERABLE_CREDENTIALS = mock(SaltedTokenHash.class);
|
||||
|
||||
private static final Collection<TestAccount> EXTENSION_TEST_ACCOUNTS = new HashSet<>();
|
||||
|
||||
public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter() {
|
||||
when(VALID_CREDENTIALS.verify("foo")).thenReturn(true);
|
||||
when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true);
|
||||
|
@ -138,7 +147,7 @@ public class AuthHelper {
|
|||
when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY));
|
||||
when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED));
|
||||
|
||||
when(VALID_ACCOUNT_TWO.getEnabledDeviceCount()).thenReturn(6);
|
||||
when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true);
|
||||
|
||||
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
|
||||
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
|
||||
|
@ -261,6 +270,11 @@ public class AuthHelper {
|
|||
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
|
||||
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
|
||||
}
|
||||
|
||||
private void teardown(final AccountsManager accountsManager) {
|
||||
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.empty());
|
||||
when(accountsManager.getByE164(number)).thenReturn(Optional.empty());
|
||||
}
|
||||
}
|
||||
|
||||
private static TestAccount[] generateTestAccounts() {
|
||||
|
@ -272,4 +286,35 @@ public class AuthHelper {
|
|||
}
|
||||
return testAccounts;
|
||||
}
|
||||
|
||||
/**
|
||||
* JUnit 5 extension for creating {@link TestAccount}s scoped to a single test
|
||||
*/
|
||||
public static class AuthFilterExtension implements AfterEachCallback {
|
||||
|
||||
public TestAccount createTestAccount() {
|
||||
final UUID uuid = UUID.randomUUID();
|
||||
final String region = new ArrayList<>((PhoneNumberUtil.getInstance().getSupportedRegions())).get(
|
||||
EXTENSION_TEST_ACCOUNTS.size());
|
||||
final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().getExampleNumber(region);
|
||||
|
||||
final TestAccount testAccount = new TestAccount(
|
||||
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164), uuid,
|
||||
"extension-password-" + region);
|
||||
testAccount.setup(ACCOUNTS_MANAGER);
|
||||
|
||||
EXTENSION_TEST_ACCOUNTS.add(testAccount);
|
||||
|
||||
return testAccount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterEach(final ExtensionContext context) {
|
||||
EXTENSION_TEST_ACCOUNTS.forEach(testAccount -> {
|
||||
testAccount.teardown(ACCOUNTS_MANAGER);
|
||||
});
|
||||
|
||||
EXTENSION_TEST_ACCOUNTS.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue