Add hashKey to RemoteConfig

This allows the percentages for different entries in remote config to
be aligned so one remote config can be a subset of another.
This commit is contained in:
Ehren Kret 2020-05-13 11:08:22 -07:00
parent 674e63cd3e
commit 7da9e88c0b
9 changed files with 200 additions and 56 deletions

View File

@ -23,6 +23,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
@ -50,7 +51,8 @@ public class RemoteConfigController {
MessageDigest digest = MessageDigest.getInstance("SHA1");
return new UserRemoteConfigList(remoteConfigsManager.getAll().stream().map(config -> {
boolean inBucket = isInBucket(digest, account.getUuid(), config.getName().getBytes(), config.getPercentage(), config.getUuids());
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) : config.getName().getBytes(StandardCharsets.UTF_8);
boolean inBucket = isInBucket(digest, account.getUuid(), hashKey, config.getPercentage(), config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket, inBucket ? config.getValue() : config.getDefaultValue());
}).collect(Collectors.toList()));
} catch (NoSuchAlgorithmException e) {
@ -82,7 +84,7 @@ public class RemoteConfigController {
}
@VisibleForTesting
public static boolean isInBucket(MessageDigest digest, UUID uid, byte[] configName, int configPercentage, Set<UUID> uuidsInBucket) {
public static boolean isInBucket(MessageDigest digest, UUID uid, byte[] hashKey, int configPercentage, Set<UUID> uuidsInBucket) {
if (uuidsInBucket.contains(uid)) return true;
ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
@ -91,7 +93,7 @@ public class RemoteConfigController {
digest.update(bb.array());
byte[] hash = digest.digest(configName);
byte[] hash = digest.digest(hashKey);
int bucket = (int)(Math.abs(Conversions.byteArrayToLong(hash)) % 100);
return bucket < configPercentage;

View File

@ -32,14 +32,18 @@ public class RemoteConfig {
@JsonProperty
private String value;
@JsonProperty
private String hashKey;
public RemoteConfig() {}
public RemoteConfig(String name, int percentage, Set<UUID> uuids, String defaultValue, String value) {
public RemoteConfig(String name, int percentage, Set<UUID> uuids, String defaultValue, String value, String hashKey) {
this.name = name;
this.percentage = percentage;
this.uuids = uuids;
this.defaultValue = defaultValue;
this.value = value;
this.hashKey = hashKey;
}
public int getPercentage() {
@ -61,4 +65,8 @@ public class RemoteConfig {
public String getValue() {
return value;
}
public String getHashKey() {
return hashKey;
}
}

View File

@ -19,6 +19,7 @@ public class RemoteConfigs {
public static final String UUIDS = "uuids";
public static final String DEFAULT_VALUE = "default_value";
public static final String VALUE = "value";
public static final String HASH_KEY = "hash_key";
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer setTimer = metricRegistry.timer(name(Accounts.class, "set" ));
@ -36,12 +37,13 @@ public class RemoteConfigs {
public void set(RemoteConfig remoteConfig) {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = setTimer.time()) {
handle.createUpdate("INSERT INTO remote_config (" + NAME + ", " + PERCENTAGE + ", " + UUIDS + ", " + DEFAULT_VALUE + ", " + VALUE + ") VALUES (:name, :percentage, :uuids, :default_value, :value) ON CONFLICT(" + NAME + ") DO UPDATE SET " + PERCENTAGE + " = EXCLUDED." + PERCENTAGE + ", " + UUIDS + " = EXCLUDED." + UUIDS + ", " + DEFAULT_VALUE + " = EXCLUDED." + DEFAULT_VALUE + ", " + VALUE + " = EXCLUDED." + VALUE)
handle.createUpdate("INSERT INTO remote_config (" + NAME + ", " + PERCENTAGE + ", " + UUIDS + ", " + DEFAULT_VALUE + ", " + VALUE + ", " + HASH_KEY + ") VALUES (:name, :percentage, :uuids, :default_value, :value, :hash_key) ON CONFLICT(" + NAME + ") DO UPDATE SET " + PERCENTAGE + " = EXCLUDED." + PERCENTAGE + ", " + UUIDS + " = EXCLUDED." + UUIDS + ", " + DEFAULT_VALUE + " = EXCLUDED." + DEFAULT_VALUE + ", " + VALUE + " = EXCLUDED." + VALUE + ", " + HASH_KEY + " = EXCLUDED." + HASH_KEY)
.bind("name", remoteConfig.getName())
.bind("percentage", remoteConfig.getPercentage())
.bind("uuids", remoteConfig.getUuids().toArray(new UUID[0]))
.bind("default_value", remoteConfig.getDefaultValue())
.bind("value", remoteConfig.getValue())
.bind("hash_key", remoteConfig.getHashKey())
.execute();
}
}));

View File

@ -20,6 +20,7 @@ public class RemoteConfigRowMapper implements RowMapper<RemoteConfig> {
rs.getInt(RemoteConfigs.PERCENTAGE),
new HashSet<>(Arrays.asList((UUID[])rs.getArray(RemoteConfigs.UUIDS).getArray())),
rs.getString(RemoteConfigs.DEFAULT_VALUE),
rs.getString(RemoteConfigs.VALUE));
rs.getString(RemoteConfigs.VALUE),
rs.getString(RemoteConfigs.HASH_KEY));
}
}

View File

@ -318,4 +318,10 @@
</addColumn>
</changeSet>
<changeSet id="18" author="ehren">
<addColumn tableName="remote_config">
<column name="hash_key" type="text"/>
</addColumn>
</changeSet>
</databaseChangeLog>

View File

@ -10,6 +10,7 @@ import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.controllers.RemoteConfigController;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
@ -22,14 +23,13 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
@ -56,13 +56,16 @@ public class RemoteConfigControllerTest {
@Before
public void setup() {
when(remoteConfigsManager.getAll()).thenReturn(new LinkedList<>() {{
add(new RemoteConfig("android.stickers", 25, Set.of(AuthHelper.DISABLED_UUID, AuthHelper.INVALID_UUID), null, null));
add(new RemoteConfig("ios.stickers", 50, Set.of(), null, null));
add(new RemoteConfig("always.true", 100, Set.of(), null, null));
add(new RemoteConfig("only.special", 0, Set.of(AuthHelper.VALID_UUID), null, null));
add(new RemoteConfig("value.always.true", 100, Set.of(), "foo", "bar"));
add(new RemoteConfig("value.only.special", 0, Set.of(AuthHelper.VALID_UUID), "abc", "xyz"));
add(new RemoteConfig("value.always.false", 0, Set.of(), "red", "green"));
add(new RemoteConfig("android.stickers", 25, Set.of(AuthHelper.DISABLED_UUID, AuthHelper.INVALID_UUID), null, null, null));
add(new RemoteConfig("ios.stickers", 50, Set.of(), null, null, null));
add(new RemoteConfig("always.true", 100, Set.of(), null, null, null));
add(new RemoteConfig("only.special", 0, Set.of(AuthHelper.VALID_UUID), null, null, null));
add(new RemoteConfig("value.always.true", 100, Set.of(), "foo", "bar", null));
add(new RemoteConfig("value.only.special", 0, Set.of(AuthHelper.VALID_UUID), "abc", "xyz", null));
add(new RemoteConfig("value.always.false", 0, Set.of(), "red", "green", null));
add(new RemoteConfig("linked.config.0", 50, Set.of(), null, null, null));
add(new RemoteConfig("linked.config.1", 50, Set.of(), null, null, "linked.config.0"));
add(new RemoteConfig("unlinked.config", 50, Set.of(), null, null, null));
}});
}
@ -76,7 +79,7 @@ public class RemoteConfigControllerTest {
verify(remoteConfigsManager, times(1)).getAll();
assertThat(configuration.getConfig().size()).isEqualTo(7);
assertThat(configuration.getConfig()).hasSize(10);
assertThat(configuration.getConfig().get(0).getName()).isEqualTo("android.stickers");
assertThat(configuration.getConfig().get(1).getName()).isEqualTo("ios.stickers");
assertThat(configuration.getConfig().get(2).getName()).isEqualTo("always.true");
@ -94,6 +97,9 @@ public class RemoteConfigControllerTest {
assertThat(configuration.getConfig().get(6).getName()).isEqualTo("value.always.false");
assertThat(configuration.getConfig().get(6).isEnabled()).isEqualTo(false);
assertThat(configuration.getConfig().get(6).getValue()).isEqualTo("red");
assertThat(configuration.getConfig().get(7).getName()).isEqualTo("linked.config.0");
assertThat(configuration.getConfig().get(8).getName()).isEqualTo("linked.config.1");
assertThat(configuration.getConfig().get(9).getName()).isEqualTo("unlinked.config");
}
@Test
@ -106,7 +112,7 @@ public class RemoteConfigControllerTest {
verify(remoteConfigsManager, times(1)).getAll();
assertThat(configuration.getConfig().size()).isEqualTo(7);
assertThat(configuration.getConfig()).hasSize(10);
assertThat(configuration.getConfig().get(0).getName()).isEqualTo("android.stickers");
assertThat(configuration.getConfig().get(1).getName()).isEqualTo("ios.stickers");
assertThat(configuration.getConfig().get(2).getName()).isEqualTo("always.true");
@ -124,6 +130,31 @@ public class RemoteConfigControllerTest {
assertThat(configuration.getConfig().get(6).getName()).isEqualTo("value.always.false");
assertThat(configuration.getConfig().get(6).isEnabled()).isEqualTo(false);
assertThat(configuration.getConfig().get(6).getValue()).isEqualTo("red");
assertThat(configuration.getConfig().get(7).getName()).isEqualTo("linked.config.0");
assertThat(configuration.getConfig().get(8).getName()).isEqualTo("linked.config.1");
assertThat(configuration.getConfig().get(9).getName()).isEqualTo("unlinked.config");
}
@Test
public void testHashKeyLinkedConfigs() {
boolean allUnlinkedConfigsMatched = true;
for (AuthHelper.TestAccount testAccount : AuthHelper.TEST_ACCOUNTS) {
UserRemoteConfigList configuration = resources.getJerseyTest().target("/v1/config/").request().header("Authorization", testAccount.getAuthHeader()).get(UserRemoteConfigList.class);
assertThat(configuration.getConfig()).hasSize(10);
final UserRemoteConfig linkedConfig0 = configuration.getConfig().get(7);
assertThat(linkedConfig0.getName()).isEqualTo("linked.config.0");
final UserRemoteConfig linkedConfig1 = configuration.getConfig().get(8);
assertThat(linkedConfig1.getName()).isEqualTo("linked.config.1");
final UserRemoteConfig unlinkedConfig = configuration.getConfig().get(9);
assertThat(unlinkedConfig.getName()).isEqualTo("unlinked.config");
assertThat(linkedConfig0.isEnabled() == linkedConfig1.isEnabled()).isTrue();
allUnlinkedConfigsMatched &= (linkedConfig0.isEnabled() == unlinkedConfig.isEnabled());
}
assertThat(allUnlinkedConfigsMatched).isFalse();
}
@ -147,7 +178,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
@ -166,7 +197,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("value.sometimes", 50, Set.of(), "a", "b"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("value.sometimes", 50, Set.of(), "a", "b", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204);
@ -179,13 +210,49 @@ public class RemoteConfigControllerTest {
assertThat(captor.getValue().getUuids()).isEmpty();
}
@Test
public void testSetConfigWithHashKey() {
Response response1 = resources.getJerseyTest()
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("linked.config.0", 50, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response1.getStatus()).isEqualTo(204);
Response response2 = resources.getJerseyTest()
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("linked.config.1", 50, Set.of(), "FALSE", "TRUE", "linked.config.0"), MediaType.APPLICATION_JSON_TYPE));
assertThat(response2.getStatus()).isEqualTo(204);
ArgumentCaptor<RemoteConfig> captor = ArgumentCaptor.forClass(RemoteConfig.class);
verify(remoteConfigsManager, times(2)).set(captor.capture());
assertThat(captor.getAllValues()).hasSize(2);
final RemoteConfig capture1 = captor.getAllValues().get(0);
assertThat(capture1).isNotNull();
assertThat(capture1.getName()).isEqualTo("linked.config.0");
assertThat(capture1.getPercentage()).isEqualTo(50);
assertThat(capture1.getUuids()).isEmpty();
assertThat(capture1.getHashKey()).isNull();
final RemoteConfig capture2 = captor.getAllValues().get(1);
assertThat(capture2).isNotNull();
assertThat(capture2.getName()).isEqualTo("linked.config.1");
assertThat(capture2.getPercentage()).isEqualTo(50);
assertThat(capture2.getUuids()).isEmpty();
assertThat(capture2.getHashKey()).isEqualTo("linked.config.0");
}
@Test
public void testSetConfigUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/config")
.request()
.header("Config-Token", "baz")
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401);
@ -197,7 +264,7 @@ public class RemoteConfigControllerTest {
Response response = resources.getJerseyTest()
.target("/v1/config")
.request()
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("android.stickers", 88, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401);
@ -210,7 +277,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("android-stickers", 88, Set.of(), "FALSE", "TRUE"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("android-stickers", 88, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
@ -223,7 +290,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config")
.request()
.header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("", 88, Set.of(), "FALSE", "TRUE"), MediaType.APPLICATION_JSON_TYPE));
.put(Entity.entity(new RemoteConfig("", 88, Set.of(), "FALSE", "TRUE", null), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422);
@ -264,13 +331,13 @@ public class RemoteConfigControllerTest {
Map<String, Integer> enabledMap = new HashMap<>();
MessageDigest digest = MessageDigest.getInstance("SHA1");
int iterations = 100000;
SecureRandom secureRandom = new SecureRandom(new byte[]{42}); // the seed value doesn't matter so much as it's constant to make the test not flaky
Random random = new Random(9424242L); // the seed value doesn't matter so much as it's constant to make the test not flaky
for (int i=0;i<iterations;i++) {
for (RemoteConfig config : remoteConfigList) {
int count = enabledMap.getOrDefault(config.getName(), 0);
if (RemoteConfigController.isInBucket(digest, getRandomUUID(secureRandom), config.getName().getBytes(), config.getPercentage(), new HashSet<>())) {
if (RemoteConfigController.isInBucket(digest, AuthHelper.getRandomUUID(random), config.getName().getBytes(), config.getPercentage(), new HashSet<>())) {
count++;
}
@ -286,14 +353,4 @@ public class RemoteConfigControllerTest {
}
}
private static UUID getRandomUUID(SecureRandom secureRandom) {
long mostSignificantBits = secureRandom.nextLong();
long leastSignificantBits = secureRandom.nextLong();
mostSignificantBits &= 0xffffffffffff0fffL;
mostSignificantBits |= 0x0000000000004000L;
leastSignificantBits &= 0x3fffffffffffffffL;
leastSignificantBits |= 0x8000000000000000L;
return new UUID(mostSignificantBits, leastSignificantBits);
}
}

View File

@ -35,11 +35,11 @@ public class RemoteConfigsManagerTest {
@Test
public void testUpdate() throws InterruptedException {
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.sometimes", 50, Set.of(), "bar", "baz"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.sometimes", 25, Set.of(AuthHelper.VALID_UUID), "abc", "def"));
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.sometimes", 50, Set.of(), "bar", "baz", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.sometimes", 25, Set.of(AuthHelper.VALID_UUID), "abc", "def", null));
Thread.sleep(501);

View File

@ -32,8 +32,8 @@ public class RemoteConfigsTest {
@Test
public void testStore() {
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID, AuthHelper.VALID_UUID_TWO), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.sometimes", 25, Set.of(AuthHelper.VALID_UUID_TWO), "default", "custom"));
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID, AuthHelper.VALID_UUID_TWO), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.sometimes", 25, Set.of(AuthHelper.VALID_UUID_TWO), "default", "custom", null));
List<RemoteConfig> configs = remoteConfigs.getAll();
@ -60,11 +60,11 @@ public class RemoteConfigsTest {
@Test
public void testUpdate() {
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.sometimes", 22, Set.of(), "def", "!"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(AuthHelper.DISABLED_UUID), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.sometimes", 77, Set.of(), "hey", "wut"));
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.sometimes", 22, Set.of(), "def", "!", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(AuthHelper.DISABLED_UUID), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.sometimes", 77, Set.of(), "hey", "wut", null));
List<RemoteConfig> configs = remoteConfigs.getAll();
@ -91,10 +91,10 @@ public class RemoteConfigsTest {
@Test
public void testDelete() {
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE"));
remoteConfigs.set(new RemoteConfig("value.always", 100, Set.of(), "never", "always"));
remoteConfigs.set(new RemoteConfig("android.stickers", 50, Set.of(AuthHelper.VALID_UUID), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, Set.of(), "FALSE", "TRUE", null));
remoteConfigs.set(new RemoteConfig("value.always", 100, Set.of(), "never", "always", null));
remoteConfigs.delete("android.stickers");
List<RemoteConfig> configs = remoteConfigs.getAll();

View File

@ -1,6 +1,10 @@
package org.whispersystems.textsecuregcm.tests.util;
import com.google.common.collect.ImmutableMap;
import io.dropwizard.auth.AuthFilter;
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.auth.basic.BasicCredentials;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
@ -12,19 +16,22 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Base64;
import java.security.Principal;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import io.dropwizard.auth.AuthFilter;
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class AuthHelper {
// Static seed to ensure reproducible tests.
private static final Random random = new Random(0xf744df3b43a3339cL);
public static final TestAccount[] TEST_ACCOUNTS = generateTestAccounts();
public static final String VALID_NUMBER = "+14150000000";
public static final UUID VALID_UUID = UUID.randomUUID();
public static final String VALID_PASSWORD = "foo";
@ -56,7 +63,7 @@ public class AuthHelper {
private static AuthenticationCredentials VALID_CREDENTIALS_TWO = mock(AuthenticationCredentials.class);
private static AuthenticationCredentials DISABLED_CREDENTIALS = mock(AuthenticationCredentials.class);
public static PolymorphicAuthDynamicFeature getAuthFilter() {
public static PolymorphicAuthDynamicFeature<? extends Principal> getAuthFilter() {
when(VALID_CREDENTIALS.verify("foo")).thenReturn(true);
when(VALID_CREDENTIALS_TWO.verify("baz")).thenReturn(true);
when(DISABLED_CREDENTIALS.verify(DISABLED_PASSWORD)).thenReturn(true);
@ -118,6 +125,10 @@ public class AuthHelper {
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
for (TestAccount testAccount : TEST_ACCOUNTS) {
testAccount.setup(ACCOUNTS_MANAGER);
}
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
@ -132,4 +143,61 @@ public class AuthHelper {
public static String getUnidentifiedAccessHeader(byte[] key) {
return Base64.encodeBytes(key);
}
public static UUID getRandomUUID(Random random) {
long mostSignificantBits = random.nextLong();
long leastSignificantBits = random.nextLong();
mostSignificantBits &= 0xffffffffffff0fffL;
mostSignificantBits |= 0x0000000000004000L;
leastSignificantBits &= 0x3fffffffffffffffL;
leastSignificantBits |= 0x8000000000000000L;
return new UUID(mostSignificantBits, leastSignificantBits);
}
public static final class TestAccount {
public final String number;
public final UUID uuid;
public final String password;
public final Account account = mock(Account.class);
public final Device device = mock(Device.class);
public final AuthenticationCredentials authenticationCredentials = mock(AuthenticationCredentials.class);
public TestAccount(String number, UUID uuid, String password) {
this.number = number;
this.uuid = uuid;
this.password = password;
}
public String getAuthHeader() {
return AuthHelper.getAuthHeader(number, password);
}
private void setup(final AccountsManager accountsManager) {
when(authenticationCredentials.verify(password)).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(authenticationCredentials);
when(device.isMaster()).thenReturn(true);
when(device.getId()).thenReturn(1L);
when(device.isEnabled()).thenReturn(true);
when(account.getDevice(1L)).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getRelay()).thenReturn(Optional.empty());
when(account.isEnabled()).thenReturn(true);
when(accountsManager.get(number)).thenReturn(Optional.of(account));
when(accountsManager.get(uuid)).thenReturn(Optional.of(account));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(number)))).thenReturn(Optional.of(account));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account));
}
}
private static TestAccount[] generateTestAccounts() {
final TestAccount[] testAccounts = new TestAccount[20];
final long numberBase = 1_409_000_0000L;
for (int i = 0; i < testAccounts.length; i++) {
long currentNumber = numberBase + i;
testAccounts[i] = new TestAccount("+" + currentNumber, getRandomUUID(random), "TestAccountPassword-" + currentNumber);
}
return testAccounts;
}
}