Add more tests

This commit is contained in:
Ehren Kret 2021-05-18 12:40:37 -05:00
parent c595d9415c
commit 0cd3640f13
3 changed files with 173 additions and 10 deletions

View File

@ -665,7 +665,8 @@ public class MessageController {
}
}
private void validateRegistrationIds(Account account, List<IncomingMessage> messages)
@VisibleForTesting
public static void validateRegistrationIds(Account account, List<IncomingMessage> messages)
throws StaleDevicesException {
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream = messages
.stream()
@ -673,7 +674,8 @@ public class MessageController {
validateRegistrationIds(account, deviceIdAndRegistrationIdStream);
}
private void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
@VisibleForTesting
public static void validateRegistrationIds(Account account, Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream)
throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
@ -689,18 +691,16 @@ public class MessageController {
}
}
private void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage)
@VisibleForTesting
public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage);
}
private void validateCompleteDeviceList(Account account,
Set<Long> messageDeviceIds,
boolean isSyncMessage)
throws MismatchedDevicesException
{
@VisibleForTesting
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import java.util.List;
public class StaleDevicesException extends Throwable {
public class StaleDevicesException extends Exception {
private final List<Long> staleDevices;
public StaleDevicesException(List<Long> staleDevices) {

View File

@ -11,7 +11,9 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
@ -38,6 +40,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.time.Duration;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@ -51,6 +54,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.assertj.core.api.Assertions;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@ -70,7 +74,9 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.controllers.StaleDevicesException;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
@ -98,6 +104,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -624,4 +631,160 @@ class MessageControllerTest {
verify(reportMessageManager).report(senderNumber, messageGuid);
}
static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) {
Account account = mock(Account.class);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i+=2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
Device device = mock(Device.class);
when(device.getRegistrationId()).thenReturn(registrationId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
}
return account;
}
static Collection<Pair<Long, Integer>> deviceAndRegistrationIds(Object... deviceAndRegistrationIds) {
final Collection<Pair<Long, Integer>> result = new HashSet<>(deviceAndRegistrationIds.length);
if (deviceAndRegistrationIds.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
for (int i = 0; i < deviceAndRegistrationIds.length; i += 2) {
if (!(deviceAndRegistrationIds[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceAndRegistrationIds[i + 1] instanceof Integer)) {
throw new IllegalArgumentException("registration id is not instance of integer at index " + (i + 1));
}
Long deviceId = (Long) deviceAndRegistrationIds[i];
Integer registrationId = (Integer) deviceAndRegistrationIds[i + 1];
result.add(new Pair<>(deviceId, registrationId));
}
return result;
}
static Stream<Arguments> validateRegistrationIdsSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndRegId(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
deviceAndRegistrationIds(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 1492),
Set.of(1L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 42),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42),
deviceAndRegistrationIds(1L, 0),
null),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 255),
deviceAndRegistrationIds(1L, 0, 2L, 42),
Set.of(2L)),
arguments(
mockAccountWithDeviceAndRegId(1L, 42, 2L, 256),
deviceAndRegistrationIds(1L, 41, 2L, 257),
Set.of(1L, 2L))
);
}
@ParameterizedTest
@MethodSource("validateRegistrationIdsSource")
void testValidateRegistrationIds(
Account account,
Collection<Pair<Long, Integer>> deviceAndRegistrationIds,
Set<Long> expectedStaleDeviceIds) throws Exception {
if (expectedStaleDeviceIds != null) {
Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> {
MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}).getStaleDevices()).hasSameElementsAs(expectedStaleDeviceIds);
} else {
MessageController.validateRegistrationIds(account, deviceAndRegistrationIds.stream());
}
}
static Account mockAccountWithDeviceAndEnabled(Object... deviceIdAndEnabled) {
Account account = mock(Account.class);
if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even");
}
final Set<Device> devices = new HashSet<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i);
}
if (!(deviceIdAndEnabled[i + 1] instanceof Boolean)) {
throw new IllegalArgumentException("enabled is not instance of boolean at index " + (i + 1));
}
Long deviceId = (Long) deviceIdAndEnabled[i];
Boolean enabled = (Boolean) deviceIdAndEnabled[i + 1];
Device device = mock(Device.class);
when(device.isEnabled()).thenReturn(enabled);
when(device.getId()).thenReturn(deviceId);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
devices.add(device);
}
when(account.getDevices()).thenReturn(devices);
return account;
}
static Stream<Arguments> validateCompleteDeviceListSource() {
return Stream.of(
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L)),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L))
);
}
@ParameterizedTest
@MethodSource("validateCompleteDeviceListSource")
void testValidateCompleteDeviceList(
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageController.validateCompleteDeviceList(account, deviceIds, false));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
}
if (expectedExtraDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageController.validateCompleteDeviceList(account, deviceIds, false);
}
}
}