diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index ae21def9b..3ac14bfdf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -665,7 +665,8 @@ public class MessageController { } } - private void validateRegistrationIds(Account account, List messages) + @VisibleForTesting + public static void validateRegistrationIds(Account account, List messages) throws StaleDevicesException { final Stream> deviceIdAndRegistrationIdStream = messages .stream() @@ -673,7 +674,8 @@ public class MessageController { validateRegistrationIds(account, deviceIdAndRegistrationIdStream); } - private void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) + @VisibleForTesting + public static void validateRegistrationIds(Account account, Stream> deviceIdAndRegistrationIdStream) throws StaleDevicesException { final List staleDevices = deviceIdAndRegistrationIdStream .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) @@ -689,18 +691,16 @@ public class MessageController { } } - private void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage) + @VisibleForTesting + public static void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage) throws MismatchedDevicesException { Set messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()); validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage); } - - private void validateCompleteDeviceList(Account account, - Set messageDeviceIds, - boolean isSyncMessage) - throws MismatchedDevicesException - { + @VisibleForTesting + public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage) + throws MismatchedDevicesException { Set accountDeviceIds = new HashSet<>(); List missingDeviceIds = new LinkedList<>(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java index c65a304c0..7e914f176 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.controllers; import java.util.List; -public class StaleDevicesException extends Throwable { +public class StaleDevicesException extends Exception { private final List staleDevices; public StaleDevicesException(List staleDevices) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index 5dc1af0d5..be5495ea3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -11,7 +11,9 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -38,6 +40,7 @@ import io.dropwizard.testing.junit5.ResourceExtension; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.time.Duration; import java.util.Base64; +import java.util.Collection; import java.util.HashSet; import java.util.LinkedList; import java.util.List; @@ -51,6 +54,7 @@ import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.assertj.core.api.Assertions; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -70,7 +74,9 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.controllers.StaleDevicesException; import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; @@ -98,6 +104,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; @ExtendWith(DropwizardExtensionsSupport.class) @@ -624,4 +631,160 @@ class MessageControllerTest { verify(reportMessageManager).report(senderNumber, messageGuid); } + + static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) { + Account account = mock(Account.class); + if (deviceAndRegistrationIds.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) { + if (!(deviceAndRegistrationIds[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { + throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); + } + Long deviceId = (Long) deviceAndRegistrationIds[i]; + Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; + Device device = mock(Device.class); + when(device.getRegistrationId()).thenReturn(registrationId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + } + return account; + } + + static Collection> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) { + final Collection> result = new HashSet<>(deviceAndRegistrationIds.length); + if (deviceAndRegistrationIds.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) { + if (!(deviceAndRegistrationIds[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) { + throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1)); + } + Long deviceId = (Long) deviceAndRegistrationIds[i]; + Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1]; + result.add(new Pair<>(deviceId, registrationId)); + } + return result; + } + + static Stream validateRegistrationIdsSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 1492), + Set.of(1L)), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 42), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42), + deviceAndRegistrationIds(1L, 0), + null), + arguments( + mockAccountWithDeviceAndRegId(1L, 42, 2L, 255), + deviceAndRegistrationIds(1L, 0, 2L, 42), + Set.of(2L)), + arguments( + mockAccountWithDeviceAndRegId(1L, 42, 2L, 256), + deviceAndRegistrationIds(1L, 41, 2L, 257), + Set.of(1L, 2L)) + ); + } + + @ParameterizedTest + @MethodSource("validateRegistrationIdsSource") + void testValidateRegistrationIds( + Account account, + Collection> deviceAndRegistrationIds, + Set expectedStaleDeviceIds) throws Exception { + if (expectedStaleDeviceIds != null) { + Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> { + MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); + }).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds); + } else { + MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream()); + } + } + + static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) { + Account account = mock(Account.class); + if (deviceIdAndEnabled.length % 2 != 0) { + throw new IllegalArgumentException("invalid number of arguments specified; must be even"); + } + final Set devices = new HashSet<>(deviceIdAndEnabled.length / 2); + for (int i = 0; i < deviceIdAndEnabled.length; i+=2) { + if (!(deviceIdAndEnabled[i] instanceof Long)) { + throw new IllegalArgumentException("device id is not instance of long at index " + i); + } + if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) { + throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1)); + } + Long deviceId = (Long) deviceIdAndEnabled[i]; + Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1]; + Device device = mock(Device.class); + when(device.isEnabled()).thenReturn(enabled); + when(device.getId()).thenReturn(deviceId); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + devices.add(device); + } + when(account.getDevices()).thenReturn(devices); + return account; + } + + static Stream validateCompleteDeviceListSource() { + return Stream.of( + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 3L), + null, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 2L, 3L), + null, + Set.of(2L)), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L), + Set.of(3L), + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L, 2L), + Set.of(3L), + Set.of(2L)) + ); + } + + @ParameterizedTest + @MethodSource("validateCompleteDeviceListSource") + void testValidateCompleteDeviceList( + Account account, + Set deviceIds, + Collection expectedMissingDeviceIds, + Collection expectedExtraDeviceIds) throws Exception { + if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { + final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, + () -> MessageController.validateCompleteDeviceList(account, deviceIds, false)); + if (expectedMissingDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) + .hasSameElementsAs(expectedMissingDeviceIds); + } + if (expectedExtraDeviceIds != null) { + Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); + } + } else { + MessageController.validateCompleteDeviceList(account, deviceIds, false); + } + } }