Add support for UUID buckets in remote config

This commit is contained in:
Moxie Marlinspike 2020-01-21 14:38:47 -08:00
parent 08a70664f4
commit e4e20c2d25
8 changed files with 135 additions and 31 deletions

View File

@ -21,9 +21,12 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.nio.ByteBuffer;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
@ -46,12 +49,12 @@ public class RemoteConfigController {
public UserRemoteConfigList getAll(@Auth Account account) { public UserRemoteConfigList getAll(@Auth Account account) {
try { try {
MessageDigest digest = MessageDigest.getInstance("SHA1"); MessageDigest digest = MessageDigest.getInstance("SHA1");
byte[] number = account.getNumber().getBytes();
return new UserRemoteConfigList(remoteConfigsManager.getAll().stream().map(config -> new UserRemoteConfig(config.getName(), return new UserRemoteConfigList(remoteConfigsManager.getAll().stream().map(config -> new UserRemoteConfig(config.getName(),
isInBucket(digest, number, isInBucket(digest, account.getUuid(),
config.getName().getBytes(), config.getName().getBytes(),
config.getPercentage()))) config.getPercentage(),
config.getUuids())))
.collect(Collectors.toList())); .collect(Collectors.toList()));
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
throw new AssertionError(e); throw new AssertionError(e);
@ -82,8 +85,14 @@ public class RemoteConfigController {
} }
@VisibleForTesting @VisibleForTesting
public static boolean isInBucket(MessageDigest digest, byte[] user, byte[] configName, int configPercentage) { public static boolean isInBucket(MessageDigest digest, UUID uid, byte[] configName, int configPercentage, Set<UUID> uuidsInBucket) {
digest.update(user); if (uuidsInBucket.contains(uid)) return true;
ByteBuffer bb = ByteBuffer.wrap(new byte[16]);
bb.putLong(uid.getMostSignificantBits());
bb.putLong(uid.getLeastSignificantBits());
digest.update(bb.array());
byte[] hash = digest.digest(configName); byte[] hash = digest.digest(configName);
int bucket = (int)(Math.abs(Conversions.byteArrayToLong(hash)) % 100); int bucket = (int)(Math.abs(Conversions.byteArrayToLong(hash)) % 100);
@ -93,7 +102,7 @@ public class RemoteConfigController {
@SuppressWarnings("BooleanMethodIsAlwaysInverted") @SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean isAuthorized(String configToken) { private boolean isAuthorized(String configToken) {
return configAuthTokens.stream().anyMatch(authorized -> MessageDigest.isEqual(authorized.getBytes(), configToken.getBytes())); return configToken != null && configAuthTokens.stream().anyMatch(authorized -> MessageDigest.isEqual(authorized.getBytes(), configToken.getBytes()));
} }
} }

View File

@ -6,6 +6,11 @@ import javax.validation.constraints.Max;
import javax.validation.constraints.Min; import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.validation.constraints.Pattern; import javax.validation.constraints.Pattern;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.UUID;
public class RemoteConfig { public class RemoteConfig {
@ -19,11 +24,16 @@ public class RemoteConfig {
@Max(100) @Max(100)
private int percentage; private int percentage;
@JsonProperty
@NotNull
private Set<UUID> uuids = new HashSet<>();
public RemoteConfig() {} public RemoteConfig() {}
public RemoteConfig(String name, int percentage) { public RemoteConfig(String name, int percentage, Set<UUID> uuids) {
this.name = name; this.name = name;
this.percentage = percentage; this.percentage = percentage;
this.uuids = uuids;
} }
public int getPercentage() { public int getPercentage() {
@ -33,4 +43,8 @@ public class RemoteConfig {
public String getName() { public String getName() {
return name; return name;
} }
public Set<UUID> getUuids() {
return uuids;
}
} }

View File

@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.storage.mappers.RemoteConfigRowMapper;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -22,8 +23,7 @@ public class RemoteConfigs {
public static final String ID = "id"; public static final String ID = "id";
public static final String NAME = "name"; public static final String NAME = "name";
public static final String PERCENTAGE = "percentage"; public static final String PERCENTAGE = "percentage";
public static final String UUIDS = "uuids";
private static final ObjectMapper mapper = SystemMapper.getMapper();
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer setTimer = metricRegistry.timer(name(Accounts.class, "set" )); private final Timer setTimer = metricRegistry.timer(name(Accounts.class, "set" ));
@ -35,14 +35,16 @@ public class RemoteConfigs {
public RemoteConfigs(FaultTolerantDatabase database) { public RemoteConfigs(FaultTolerantDatabase database) {
this.database = database; this.database = database;
this.database.getDatabase().registerRowMapper(new RemoteConfigRowMapper()); this.database.getDatabase().registerRowMapper(new RemoteConfigRowMapper());
this.database.getDatabase().registerArrayType(UUID.class, "uuid");
} }
public void set(RemoteConfig remoteConfig) { public void set(RemoteConfig remoteConfig) {
database.use(jdbi -> jdbi.useHandle(handle -> { database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = setTimer.time()) { try (Timer.Context ignored = setTimer.time()) {
handle.createUpdate("INSERT INTO remote_config (" + NAME + ", " + PERCENTAGE + ") VALUES (:name, :percentage) ON CONFLICT(" + NAME + ") DO UPDATE SET " + PERCENTAGE + " = EXCLUDED." + PERCENTAGE) handle.createUpdate("INSERT INTO remote_config (" + NAME + ", " + PERCENTAGE + ", " + UUIDS + ") VALUES (:name, :percentage, :uuids) ON CONFLICT(" + NAME + ") DO UPDATE SET " + PERCENTAGE + " = EXCLUDED." + PERCENTAGE + ", " + UUIDS + " = EXCLUDED." + UUIDS)
.bind("name", remoteConfig.getName()) .bind("name", remoteConfig.getName())
.bind("percentage", remoteConfig.getPercentage()) .bind("percentage", remoteConfig.getPercentage())
.bind("uuids", remoteConfig.getUuids().toArray(new UUID[0]))
.execute(); .execute();
} }
})); }));

View File

@ -7,12 +7,15 @@ import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.UUID;
public class RemoteConfigRowMapper implements RowMapper<RemoteConfig> { public class RemoteConfigRowMapper implements RowMapper<RemoteConfig> {
@Override @Override
public RemoteConfig map(ResultSet rs, StatementContext ctx) throws SQLException { public RemoteConfig map(ResultSet rs, StatementContext ctx) throws SQLException {
return new RemoteConfig(rs.getString(RemoteConfigs.NAME), rs.getInt(RemoteConfigs.PERCENTAGE)); return new RemoteConfig(rs.getString(RemoteConfigs.NAME), rs.getInt(RemoteConfigs.PERCENTAGE), new HashSet<>(Arrays.asList((UUID[])rs.getArray(RemoteConfigs.UUIDS).getArray())));
} }
} }

View File

@ -287,7 +287,7 @@
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="uuids" type="text []"> <column name="uuids" type="uuid[]">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
</createTable> </createTable>

View File

@ -22,9 +22,11 @@ import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule; import io.dropwizard.testing.junit.ResourceTestRule;
@ -52,9 +54,22 @@ public class RemoteConfigControllerTest {
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
when(remoteConfigsManager.getAll()).thenReturn(new LinkedList<>() {{ when(remoteConfigsManager.getAll()).thenReturn(new LinkedList<>() {{
add(new RemoteConfig("android.stickers", 25)); add(new RemoteConfig("android.stickers", 25, new HashSet<>() {{
add(new RemoteConfig("ios.stickers", 50)); add(AuthHelper.DISABLED_UUID);
add(new RemoteConfig("always.true", 100)); add(AuthHelper.INVALID_UUID);
}}));
add(new RemoteConfig("ios.stickers", 50, new HashSet<>() {{
}}));
add(new RemoteConfig("always.true", 100, new HashSet<>() {{
}}));
add(new RemoteConfig("only.special", 0, new HashSet<>() {{
add(AuthHelper.VALID_UUID);
}}));
}}); }});
} }
@ -68,13 +83,33 @@ public class RemoteConfigControllerTest {
verify(remoteConfigsManager, times(1)).getAll(); verify(remoteConfigsManager, times(1)).getAll();
assertThat(configuration.getConfig().size()).isEqualTo(3); assertThat(configuration.getConfig().size()).isEqualTo(4);
assertThat(configuration.getConfig().get(0).getName()).isEqualTo("android.stickers"); assertThat(configuration.getConfig().get(0).getName()).isEqualTo("android.stickers");
assertThat(configuration.getConfig().get(1).getName()).isEqualTo("ios.stickers"); assertThat(configuration.getConfig().get(1).getName()).isEqualTo("ios.stickers");
assertThat(configuration.getConfig().get(2).getName()).isEqualTo("always.true"); assertThat(configuration.getConfig().get(2).getName()).isEqualTo("always.true");
assertThat(configuration.getConfig().get(2).isEnabled()).isEqualTo(true); assertThat(configuration.getConfig().get(2).isEnabled()).isEqualTo(true);
assertThat(configuration.getConfig().get(3).isEnabled()).isEqualTo(true);
} }
@Test
public void testRetrieveConfigNotSpecial() {
UserRemoteConfigList configuration = resources.getJerseyTest()
.target("/v1/config/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER_TWO, AuthHelper.VALID_PASSWORD_TWO))
.get(UserRemoteConfigList.class);
verify(remoteConfigsManager, times(1)).getAll();
assertThat(configuration.getConfig().size()).isEqualTo(4);
assertThat(configuration.getConfig().get(0).getName()).isEqualTo("android.stickers");
assertThat(configuration.getConfig().get(1).getName()).isEqualTo("ios.stickers");
assertThat(configuration.getConfig().get(2).getName()).isEqualTo("always.true");
assertThat(configuration.getConfig().get(2).isEnabled()).isEqualTo(true);
assertThat(configuration.getConfig().get(3).isEnabled()).isEqualTo(false);
}
@Test @Test
public void testRetrieveConfigUnauthorized() { public void testRetrieveConfigUnauthorized() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
@ -95,7 +130,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config") .target("/v1/config")
.request() .request()
.header("Config-Token", "foo") .header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("android.stickers", 88), MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(new RemoteConfig("android.stickers", 88, new HashSet<>()), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
@ -105,6 +140,7 @@ public class RemoteConfigControllerTest {
assertThat(captor.getValue().getName()).isEqualTo("android.stickers"); assertThat(captor.getValue().getName()).isEqualTo("android.stickers");
assertThat(captor.getValue().getPercentage()).isEqualTo(88); assertThat(captor.getValue().getPercentage()).isEqualTo(88);
assertThat(captor.getValue().getUuids()).isEmpty();
} }
@Test @Test
@ -113,7 +149,19 @@ public class RemoteConfigControllerTest {
.target("/v1/config") .target("/v1/config")
.request() .request()
.header("Config-Token", "baz") .header("Config-Token", "baz")
.put(Entity.entity(new RemoteConfig("android.stickers", 88), MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(new RemoteConfig("android.stickers", 88, new HashSet<>()), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401);
verifyNoMoreInteractions(remoteConfigsManager);
}
@Test
public void testSetConfigMissingUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/config")
.request()
.put(Entity.entity(new RemoteConfig("android.stickers", 88, new HashSet<>()), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
@ -126,7 +174,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config") .target("/v1/config")
.request() .request()
.header("Config-Token", "foo") .header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("android-stickers", 88), MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(new RemoteConfig("android-stickers", 88, new HashSet<>()), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
@ -139,7 +187,7 @@ public class RemoteConfigControllerTest {
.target("/v1/config") .target("/v1/config")
.request() .request()
.header("Config-Token", "foo") .header("Config-Token", "foo")
.put(Entity.entity(new RemoteConfig("", 88), MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(new RemoteConfig("", 88, new HashSet<>()), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(422); assertThat(response.getStatus()).isEqualTo(422);
@ -186,7 +234,7 @@ public class RemoteConfigControllerTest {
int count = enabledMap.getOrDefault(config.getName(), 0); int count = enabledMap.getOrDefault(config.getName(), 0);
int random = new SecureRandom().nextInt(iterations); int random = new SecureRandom().nextInt(iterations);
if (RemoteConfigController.isInBucket(digest, ("+121322" + String.format("%05d", random)).getBytes(), config.getName().getBytes(), config.getPercentage())) { if (RemoteConfigController.isInBucket(digest, UUID.randomUUID(), config.getName().getBytes(), config.getPercentage(), new HashSet<>())) {
count++; count++;
} }

View File

@ -12,9 +12,12 @@ import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.RemoteConfig; import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs; import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.util.HashSet;
import java.util.List; import java.util.List;
import io.dropwizard.auth.Auth;
import static org.assertj.core.api.Java6Assertions.assertThat; import static org.assertj.core.api.Java6Assertions.assertThat;
public class RemoteConfigsManagerTest { public class RemoteConfigsManagerTest {
@ -33,9 +36,11 @@ public class RemoteConfigsManagerTest {
@Test @Test
public void testUpdate() throws InterruptedException { public void testUpdate() throws InterruptedException {
remoteConfigs.set(new RemoteConfig("android.stickers", 50)); remoteConfigs.set(new RemoteConfig("android.stickers", 50, new HashSet<>() {{
remoteConfigs.set(new RemoteConfig("ios.stickers", 50)); add(AuthHelper.VALID_UUID);
remoteConfigs.set(new RemoteConfig("ios.stickers", 75)); }}));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, new HashSet<>()));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, new HashSet<>()));
Thread.sleep(501); Thread.sleep(501);
@ -44,8 +49,12 @@ public class RemoteConfigsManagerTest {
assertThat(results.size()).isEqualTo(2); assertThat(results.size()).isEqualTo(2);
assertThat(results.get(0).getName()).isEqualTo("android.stickers"); assertThat(results.get(0).getName()).isEqualTo("android.stickers");
assertThat(results.get(0).getPercentage()).isEqualTo(50); assertThat(results.get(0).getPercentage()).isEqualTo(50);
assertThat(results.get(0).getUuids().size()).isEqualTo(1);
assertThat(results.get(0).getUuids().contains(AuthHelper.VALID_UUID)).isTrue();
assertThat(results.get(1).getName()).isEqualTo("ios.stickers"); assertThat(results.get(1).getName()).isEqualTo("ios.stickers");
assertThat(results.get(1).getPercentage()).isEqualTo(75); assertThat(results.get(1).getPercentage()).isEqualTo(75);
assertThat(results.get(1).getUuids()).isEmpty();
} }

View File

@ -11,10 +11,13 @@ import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguratio
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.RemoteConfig; import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs; import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.HashSet;
import java.util.List; import java.util.List;
import io.dropwizard.auth.Auth;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
public class RemoteConfigsTest { public class RemoteConfigsTest {
@ -31,35 +34,51 @@ public class RemoteConfigsTest {
@Test @Test
public void testStore() throws SQLException { public void testStore() throws SQLException {
remoteConfigs.set(new RemoteConfig("android.stickers", 50)); remoteConfigs.set(new RemoteConfig("android.stickers", 50, new HashSet<>() {{
add(AuthHelper.VALID_UUID);
add(AuthHelper.VALID_UUID_TWO);
}}));
List<RemoteConfig> configs = remoteConfigs.getAll(); List<RemoteConfig> configs = remoteConfigs.getAll();
assertThat(configs.size()).isEqualTo(1); assertThat(configs.size()).isEqualTo(1);
assertThat(configs.get(0).getName()).isEqualTo("android.stickers"); assertThat(configs.get(0).getName()).isEqualTo("android.stickers");
assertThat(configs.get(0).getPercentage()).isEqualTo(50); assertThat(configs.get(0).getPercentage()).isEqualTo(50);
assertThat(configs.get(0).getUuids().size()).isEqualTo(2);
assertThat(configs.get(0).getUuids().contains(AuthHelper.VALID_UUID)).isTrue();
assertThat(configs.get(0).getUuids().contains(AuthHelper.VALID_UUID_TWO)).isTrue();
assertThat(configs.get(0).getUuids().contains(AuthHelper.INVALID_UUID)).isFalse();
} }
@Test @Test
public void testUpdate() throws SQLException { public void testUpdate() throws SQLException {
remoteConfigs.set(new RemoteConfig("android.stickers", 50)); remoteConfigs.set(new RemoteConfig("android.stickers", 50, new HashSet<>()));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75)); remoteConfigs.set(new RemoteConfig("ios.stickers", 50, new HashSet<>() {{
add(AuthHelper.DISABLED_UUID);
}}));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, new HashSet<>()));
List<RemoteConfig> configs = remoteConfigs.getAll(); List<RemoteConfig> configs = remoteConfigs.getAll();
assertThat(configs.size()).isEqualTo(2); assertThat(configs.size()).isEqualTo(2);
assertThat(configs.get(0).getName()).isEqualTo("android.stickers"); assertThat(configs.get(0).getName()).isEqualTo("android.stickers");
assertThat(configs.get(0).getPercentage()).isEqualTo(50); assertThat(configs.get(0).getPercentage()).isEqualTo(50);
assertThat(configs.get(0).getUuids().size()).isEqualTo(0);
assertThat(configs.get(1).getName()).isEqualTo("ios.stickers"); assertThat(configs.get(1).getName()).isEqualTo("ios.stickers");
assertThat(configs.get(1).getPercentage()).isEqualTo(75); assertThat(configs.get(1).getPercentage()).isEqualTo(75);
assertThat(configs.get(1).getUuids().size()).isEqualTo(0);
} }
@Test @Test
public void testDelete() { public void testDelete() {
remoteConfigs.set(new RemoteConfig("android.stickers", 50)); remoteConfigs.set(new RemoteConfig("android.stickers", 50, new HashSet<>() {{
remoteConfigs.set(new RemoteConfig("ios.stickers", 50)); add(AuthHelper.VALID_UUID);
remoteConfigs.set(new RemoteConfig("ios.stickers", 75)); }}));
remoteConfigs.set(new RemoteConfig("ios.stickers", 50, new HashSet<>()));
remoteConfigs.set(new RemoteConfig("ios.stickers", 75, new HashSet<>()));
remoteConfigs.delete("android.stickers"); remoteConfigs.delete("android.stickers");
List<RemoteConfig> configs = remoteConfigs.getAll(); List<RemoteConfig> configs = remoteConfigs.getAll();