diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManager.java index ec250783a..5c02276ee 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManager.java @@ -5,13 +5,12 @@ package org.whispersystems.textsecuregcm.experiment; -import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicExperimentEnrollmentConfiguration; -import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; - -import java.util.Collections; import java.util.Optional; -import java.util.Set; import java.util.UUID; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicExperimentEnrollmentConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPreRegistrationExperimentEnrollmentConfiguration; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.Util; public class ExperimentEnrollmentManager { @@ -22,24 +21,52 @@ public class ExperimentEnrollmentManager { } public boolean isEnrolled(final UUID accountUuid, final String experimentName) { + final Optional maybeConfiguration = dynamicConfigurationManager .getConfiguration().getExperimentEnrollmentConfiguration(experimentName); - final Set enrolledUuids = maybeConfiguration.map(DynamicExperimentEnrollmentConfiguration::getEnrolledUuids) - .orElse(Collections.emptySet()); + return maybeConfiguration.map(config -> { - final boolean enrolled; + if (config.getEnrolledUuids().contains(accountUuid)) { + return true; + } - if (enrolledUuids.contains(accountUuid)) { - enrolled = true; - } else { - final int threshold = maybeConfiguration.map(DynamicExperimentEnrollmentConfiguration::getEnrollmentPercentage) - .orElse(0); - final int enrollmentHash = ((accountUuid.hashCode() ^ experimentName.hashCode()) & Integer.MAX_VALUE) % 100; + return isEnrolled(accountUuid, config.getEnrollmentPercentage(), experimentName); - enrolled = enrollmentHash < threshold; - } + }).orElse(false); + } - return enrolled; + public boolean isEnrolled(final String e164, final String experimentName) { + + final Optional maybeConfiguration = dynamicConfigurationManager + .getConfiguration().getPreRegistrationEnrollmentConfiguration(experimentName); + + return maybeConfiguration.map(config -> { + + if (config.getEnrolledE164s().contains(e164)) { + return true; + } + + { + final String countryCode = Util.getCountryCode(e164); + + if (config.getIncludedCountryCodes().contains(countryCode)) { + return true; + } + + if (config.getExcludedCountryCodes().contains(countryCode)) { + return false; + } + } + + return isEnrolled(e164, config.getEnrollmentPercentage(), experimentName); + + }).orElse(false); + } + + private boolean isEnrolled(final Object entity, final int enrollmentPercentage, final String experimentName) { + final int enrollmentHash = ((entity.hashCode() ^ experimentName.hashCode()) & Integer.MAX_VALUE) % 100; + + return enrollmentHash < enrollmentPercentage; } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManagerTest.java index 9eecf6565..f07ee3e01 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/experiment/ExperimentEnrollmentManagerTest.java @@ -5,8 +5,7 @@ package org.whispersystems.textsecuregcm.experiment; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -14,53 +13,104 @@ import java.util.Collections; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicExperimentEnrollmentConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPreRegistrationExperimentEnrollmentConfiguration; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; class ExperimentEnrollmentManagerTest { private DynamicExperimentEnrollmentConfiguration experimentEnrollmentConfiguration; + private DynamicPreRegistrationExperimentEnrollmentConfiguration preRegistrationExperimentEnrollmentConfiguration; private ExperimentEnrollmentManager experimentEnrollmentManager; private Account account; private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final String EXPERIMENT_NAME = "test"; + private static final String UUID_EXPERIMENT_NAME = "uuid_test"; + + private static final String E164 = "+12025551212"; + private static final String E164_EXPERIMENT_NAME = "e164_test"; @BeforeEach void setUp() { final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); - experimentEnrollmentConfiguration = mock(DynamicExperimentEnrollmentConfiguration.class); experimentEnrollmentManager = new ExperimentEnrollmentManager(dynamicConfigurationManager); - account = mock(Account.class); + + experimentEnrollmentConfiguration = mock(DynamicExperimentEnrollmentConfiguration.class); + preRegistrationExperimentEnrollmentConfiguration = mock( + DynamicPreRegistrationExperimentEnrollmentConfiguration.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); - when(dynamicConfiguration.getExperimentEnrollmentConfiguration(EXPERIMENT_NAME)) + when(dynamicConfiguration.getExperimentEnrollmentConfiguration(UUID_EXPERIMENT_NAME)) .thenReturn(Optional.of(experimentEnrollmentConfiguration)); + when(dynamicConfiguration.getPreRegistrationEnrollmentConfiguration(E164_EXPERIMENT_NAME)) + .thenReturn(Optional.of(preRegistrationExperimentEnrollmentConfiguration)); + + account = mock(Account.class); when(account.getUuid()).thenReturn(ACCOUNT_UUID); } @Test - void testIsEnrolled() { - assertFalse(experimentEnrollmentManager.isEnrolled(account.getUuid(), EXPERIMENT_NAME)); - assertFalse(experimentEnrollmentManager.isEnrolled(account.getUuid(), EXPERIMENT_NAME + "-unrelated-experiment")); + void testIsEnrolled_UuidExperiment() { + assertFalse(experimentEnrollmentManager.isEnrolled(account.getUuid(), UUID_EXPERIMENT_NAME)); + assertFalse( + experimentEnrollmentManager.isEnrolled(account.getUuid(), UUID_EXPERIMENT_NAME + "-unrelated-experiment")); when(experimentEnrollmentConfiguration.getEnrolledUuids()).thenReturn(Set.of(ACCOUNT_UUID)); - assertTrue(experimentEnrollmentManager.isEnrolled(account.getUuid(), EXPERIMENT_NAME)); + assertTrue(experimentEnrollmentManager.isEnrolled(account.getUuid(), UUID_EXPERIMENT_NAME)); when(experimentEnrollmentConfiguration.getEnrolledUuids()).thenReturn(Collections.emptySet()); when(experimentEnrollmentConfiguration.getEnrollmentPercentage()).thenReturn(0); - assertFalse(experimentEnrollmentManager.isEnrolled(account.getUuid(), EXPERIMENT_NAME)); + assertFalse(experimentEnrollmentManager.isEnrolled(account.getUuid(), UUID_EXPERIMENT_NAME)); when(experimentEnrollmentConfiguration.getEnrollmentPercentage()).thenReturn(100); - assertTrue(experimentEnrollmentManager.isEnrolled(account.getUuid(), EXPERIMENT_NAME)); + assertTrue(experimentEnrollmentManager.isEnrolled(account.getUuid(), UUID_EXPERIMENT_NAME)); + } + + @ParameterizedTest + @MethodSource + void testIsEnrolled_PreRegistrationExperiment(final String e164, final String experimentName, + final Set enrolledE164s, final Set includedCountryCodes, final Set excludedCountryCodes, + final int enrollmentPercentage, + final boolean expectedEnrolled, final String message) { + + when(preRegistrationExperimentEnrollmentConfiguration.getEnrolledE164s()).thenReturn(enrolledE164s); + when(preRegistrationExperimentEnrollmentConfiguration.getEnrollmentPercentage()).thenReturn(enrollmentPercentage); + when(preRegistrationExperimentEnrollmentConfiguration.getIncludedCountryCodes()).thenReturn(includedCountryCodes); + when(preRegistrationExperimentEnrollmentConfiguration.getExcludedCountryCodes()).thenReturn(excludedCountryCodes); + + assertEquals(message, expectedEnrolled, experimentEnrollmentManager.isEnrolled(e164, experimentName)); + } + + static Stream testIsEnrolled_PreRegistrationExperiment() { + return Stream.of( + Arguments.of(E164, E164_EXPERIMENT_NAME, Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), 0, false, "default configuration expects no enrollment"), + Arguments + .of(E164, E164_EXPERIMENT_NAME + "-unrelated-experiment", Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), 0, false, "unknown experiment expects no enrollment"), + Arguments.of(E164, E164_EXPERIMENT_NAME, Set.of(E164), Collections.emptySet(), + Collections.emptySet(), 0, true, "explicitly enrolled E164 overrides 0% rollout"), + Arguments.of(E164, E164_EXPERIMENT_NAME, Set.of(E164), Collections.emptySet(), + Set.of("1"), 0, true, "explicitly enrolled E164 overrides excluded country code"), + Arguments.of(E164, E164_EXPERIMENT_NAME, Collections.emptySet(), Set.of("1"), + Collections.emptySet(), 0, true, "included country code overrides 0% rollout"), + Arguments.of(E164, E164_EXPERIMENT_NAME, Collections.emptySet(), Collections.emptySet(), + Set.of("1"), 100, false, "excluded country code overrides 100% rollout"), + Arguments.of(E164, E164_EXPERIMENT_NAME, Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), 100, true, "enrollment expected for 100% rollout") + ); } }