Prepare to read profile data stored as byte arrays

This commit is contained in:
Katherine Yen 2023-08-10 14:00:35 -07:00 committed by GitHub
parent bc5eed48c3
commit a71dc48b9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 252 additions and 196 deletions

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.util.ArrayList; import java.util.ArrayList;
@ -77,12 +76,7 @@ public class Profiles {
private static final Timer SET_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "set")); private static final Timer SET_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "set"));
private static final Timer GET_PROFILE_TIMER = Metrics.timer(name(Profiles.class, "get")); private static final Timer GET_PROFILE_TIMER = Metrics.timer(name(Profiles.class, "get"));
private static final Timer DELETE_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "delete")); private static final Timer DELETE_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "delete"));
private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(Profiles.class, "parseByteArray");
private static final Counter INVALID_NAME_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "name");
private static final Counter INVALID_EMOJI_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "emoji");
private static final Counter INVALID_ABOUT_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "about");
private static final Counter INVALID_PAYMENT_ADDRESS_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "paymentAddress");
public Profiles(final DynamoDbClient dynamoDbClient, public Profiles(final DynamoDbClient dynamoDbClient,
final DynamoDbAsyncClient dynamoDbAsyncClient, final DynamoDbAsyncClient dynamoDbAsyncClient,
@ -232,34 +226,23 @@ public class Profiles {
} }
private static VersionedProfile fromItem(final Map<String, AttributeValue> item) { private static VersionedProfile fromItem(final Map<String, AttributeValue> item) {
final String name = AttributeValues.getString(item, ATTR_NAME, null);
final String emoji = AttributeValues.getString(item, ATTR_EMOJI, null);
final String about = AttributeValues.getString(item, ATTR_ABOUT, null);
final String paymentAddress = AttributeValues.getString(item, ATTR_PAYMENT_ADDRESS, null);
checkValidBase64(name, INVALID_NAME_COUNTER);
checkValidBase64(emoji, INVALID_EMOJI_COUNTER);
checkValidBase64(about, INVALID_ABOUT_COUNTER);
checkValidBase64(paymentAddress, INVALID_PAYMENT_ADDRESS_COUNTER);
return new VersionedProfile( return new VersionedProfile(
AttributeValues.getString(item, ATTR_VERSION, null), AttributeValues.getString(item, ATTR_VERSION, null),
name, getBase64EncodedBytes(item, ATTR_NAME, PARSE_BYTE_ARRAY_COUNTER_NAME),
AttributeValues.getString(item, ATTR_AVATAR, null), AttributeValues.getString(item, ATTR_AVATAR, null),
emoji, getBase64EncodedBytes(item, ATTR_EMOJI, PARSE_BYTE_ARRAY_COUNTER_NAME),
about, getBase64EncodedBytes(item, ATTR_ABOUT, PARSE_BYTE_ARRAY_COUNTER_NAME),
paymentAddress, getBase64EncodedBytes(item, ATTR_PAYMENT_ADDRESS, PARSE_BYTE_ARRAY_COUNTER_NAME),
AttributeValues.getByteArray(item, ATTR_COMMITMENT, null)); AttributeValues.getByteArray(item, ATTR_COMMITMENT, null));
} }
private static void checkValidBase64(final String value, final Counter counter) { private static String getBase64EncodedBytes(final Map<String, AttributeValue> item, final String attributeName, final String counterName) {
if (StringUtils.isNotBlank(value)) { final AttributeValue attributeValue = item.get(attributeName);
try {
Base64.getDecoder().decode(value); if (attributeValue == null) {
} catch (final IllegalArgumentException e) { return null;
counter.increment();
}
} }
return Base64.getEncoder().encodeToString(AttributeValues.extractByteArray(attributeValue, counterName));
} }
public void deleteAll(final UUID uuid) { public void deleteAll(final UUID uuid) {

View File

@ -14,7 +14,10 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> { public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(SingleUseECPreKeyStore.class, "parseByteArray");
protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) {
super(dynamoDbAsyncClient, tableName); super(dynamoDbAsyncClient, tableName);
@ -31,7 +34,7 @@ public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
@Override @Override
protected ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) { protected ECPreKey getPreKeyFromItem(final Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); final byte[] publicKey = AttributeValues.extractByteArray(item.get(ATTR_PUBLIC_KEY), PARSE_BYTE_ARRAY_COUNTER_NAME);
try { try {
return new ECPreKey(keyId, new ECPublicKey(publicKey)); return new ECPreKey(keyId, new ECPublicKey(publicKey));

View File

@ -73,9 +73,6 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
private final String takeKeyTimerName = name(getClass(), "takeKey"); private final String takeKeyTimerName = name(getClass(), "takeKey");
private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; private static final String KEY_PRESENT_TAG_NAME = "keyPresent";
private final Counter parseBytesFromStringCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "string");
private final Counter readBytesFromByteArrayCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "bytes");
static final String KEY_ACCOUNT_UUID = "U"; static final String KEY_ACCOUNT_UUID = "U";
static final String KEY_DEVICE_ID_KEY_ID = "DK"; static final String KEY_DEVICE_ID_KEY_ID = "DK";
static final String ATTR_PUBLIC_KEY = "P"; static final String ATTR_PUBLIC_KEY = "P";
@ -289,24 +286,4 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
final K preKey); final K preKey);
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item); protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
readBytesFromByteArrayCounter.increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
parseBytesFromStringCounter.increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
} }

View File

@ -6,9 +6,13 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -124,4 +128,24 @@ public class AttributeValues {
public static UUID getUUID(Map<String, AttributeValue> item, String key, UUID defaultValue) { public static UUID getUUID(Map<String, AttributeValue> item, String key, UUID defaultValue) {
return AttributeValues.get(item, key).filter(av -> av.b() != null).map(AttributeValues::toUUID).orElse(defaultValue); return AttributeValues.get(item, key).filter(av -> av.b() != null).map(AttributeValues::toUUID).orElse(defaultValue);
} }
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
public static byte[] extractByteArray(final AttributeValue attributeValue, final String counterName) {
if (attributeValue.b() != null) {
Metrics.counter(counterName, "format", "bytes").increment();
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
Metrics.counter(counterName, "format", "string").increment();
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
} }

View File

@ -12,6 +12,9 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -19,78 +22,85 @@ import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import java.nio.charset.StandardCharsets; import java.util.Base64;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Random;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream; import java.util.stream.Stream;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class ProfilesTest { public class ProfilesTest {
private static final UUID ACI = UUID.randomUUID();
@RegisterExtension @RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.PROFILES); static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.PROFILES);
private Profiles profiles; private Profiles profiles;
private VersionedProfile validProfile;
@BeforeEach @BeforeEach
void setUp() { void setUp() throws InvalidInputException {
profiles = new Profiles(DYNAMO_DB_EXTENSION.getDynamoDbClient(), profiles = new Profiles(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.PROFILES.tableName()); Tables.PROFILES.tableName());
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(ACI)).serialize();
final String version = "someVersion";
final String name = generateRandomBase64FromByteArray(81);
final String validAboutEmoji = generateRandomBase64FromByteArray(60);
final String validAbout = generateRandomBase64FromByteArray(156);
final String avatar = "profiles/" + generateRandomBase64FromByteArray(16);
validProfile = new VersionedProfile(version, name, avatar, validAboutEmoji, validAbout, null, commitment);
} }
@Test @Test
void testSetGet() { void testSetGet() {
UUID uuid = UUID.randomUUID(); profiles.set(ACI, validProfile);
VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", "emoji",
"the very model of a modern major general",
null, "acommitment".getBytes());
profiles.set(uuid, profile);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); Optional<VersionedProfile> retrieved = profiles.get(ACI, validProfile.getVersion());
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName());
assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar());
assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment());
assertThat(retrieved.get().getAbout()).isEqualTo(profile.getAbout()); assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout());
assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji());
} }
@Test @Test
void testSetGetAsync() { void testSetGetAsync() {
UUID uuid = UUID.randomUUID(); profiles.setAsync(ACI, validProfile).join();
VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", "emoji",
"the very model of a modern major general",
null, "acommitment".getBytes());
profiles.setAsync(uuid, profile).join();
Optional<VersionedProfile> retrieved = profiles.getAsync(uuid, "123").join(); Optional<VersionedProfile> retrieved = profiles.getAsync(ACI, validProfile.getVersion()).join();
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName());
assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar());
assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment());
assertThat(retrieved.get().getAbout()).isEqualTo(profile.getAbout()); assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout());
assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji());
} }
@Test @Test
void testDeleteReset() { void testDeleteReset() throws InvalidInputException {
UUID uuid = UUID.randomUUID(); profiles.set(ACI, validProfile);
profiles.set(uuid, new VersionedProfile("123", "foo", "avatarLocation", "emoji",
"the very model of a modern major general",
null, "acommitment".getBytes()));
profiles.deleteAll(uuid); profiles.deleteAll(ACI);
VersionedProfile updatedProfile = new VersionedProfile("123", "name", "differentAvatarLocation", final String version = "someVersion";
"differentEmoji", "changed text", "paymentAddress", "differentcommitment".getBytes(StandardCharsets.UTF_8)); final String name = generateRandomBase64FromByteArray(81);
final String differentAvatar = "profiles/" + generateRandomBase64FromByteArray(16);
final String differentEmoji = generateRandomBase64FromByteArray(60);
final String differentAbout = generateRandomBase64FromByteArray(156);
final String paymentAddress = generateRandomBase64FromByteArray(582);
final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
profiles.set(uuid, updatedProfile); VersionedProfile updatedProfile = new VersionedProfile(version, name, differentAvatar,
differentEmoji, differentAbout, paymentAddress, commitment);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); profiles.set(ACI, updatedProfile);
Optional<VersionedProfile> retrieved = profiles.get(ACI, version);
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(updatedProfile.getName()); assertThat(retrieved.get().getName()).isEqualTo(updatedProfile.getName());
@ -101,13 +111,16 @@ public class ProfilesTest {
} }
@Test @Test
void testSetGetNullOptionalFields() { void testSetGetNullOptionalFields() throws InvalidInputException {
UUID uuid = UUID.randomUUID(); final String version = "someVersion";
VersionedProfile profile = new VersionedProfile("123", "foo", null, null, null, null, final String name = generateRandomBase64FromByteArray(81);
"acommitment".getBytes()); final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
profiles.set(uuid, profile);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); VersionedProfile profile = new VersionedProfile(version, name, null, null, null, null,
commitment);
profiles.set(ACI, profile);
Optional<VersionedProfile> retrieved = profiles.get(ACI, version);
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); assertThat(retrieved.get().getName()).isEqualTo(profile.getName());
@ -118,26 +131,30 @@ public class ProfilesTest {
} }
@Test @Test
void testSetReplace() { void testSetReplace() throws InvalidInputException {
UUID uuid = UUID.randomUUID(); profiles.set(ACI, validProfile);
VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", null, null,
"paymentAddress", "acommitment".getBytes());
profiles.set(uuid, profile);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); Optional<VersionedProfile> retrieved = profiles.get(ACI, validProfile.getVersion());
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName());
assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar());
assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment());
assertThat(retrieved.get().getAbout()).isNull(); assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout());
assertThat(retrieved.get().getAboutEmoji()).isNull(); assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji());
assertThat(retrieved.get().getPaymentAddress()).isNull();
VersionedProfile updated = new VersionedProfile("123", "bar", "baz", "emoji", "bio", null, final String differentName = generateRandomBase64FromByteArray(81);
"boof".getBytes()); final String differentEmoji = generateRandomBase64FromByteArray(60);
profiles.set(uuid, updated); final String differentAbout = generateRandomBase64FromByteArray(156);
final String differentAvatar = "profiles/" + generateRandomBase64FromByteArray(16);
final byte[] differentCommitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
retrieved = profiles.get(uuid, "123"); VersionedProfile updated = new VersionedProfile(validProfile.getVersion(), differentName, differentAvatar, differentEmoji, differentAbout, null,
differentCommitment);
profiles.set(ACI, updated);
retrieved = profiles.get(ACI, updated.getVersion());
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(updated.getName()); assertThat(retrieved.get().getName()).isEqualTo(updated.getName());
@ -146,22 +163,34 @@ public class ProfilesTest {
assertThat(retrieved.get().getAvatar()).isEqualTo(updated.getAvatar()); assertThat(retrieved.get().getAvatar()).isEqualTo(updated.getAvatar());
// Commitment should be unchanged after an overwrite // Commitment should be unchanged after an overwrite
assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment());
} }
@Test @Test
void testMultipleVersions() { void testMultipleVersions() throws InvalidInputException {
UUID uuid = UUID.randomUUID(); final String versionOne = "versionOne";
VersionedProfile profileOne = new VersionedProfile("123", "foo", "avatarLocation", null, null, final String versionTwo = "versionTwo";
null, "acommitmnet".getBytes());
VersionedProfile profileTwo = new VersionedProfile("345", "bar", "baz", "emoji",
"i keep typing emoju for some reason",
null, "boof".getBytes());
profiles.set(uuid, profileOne); final String nameOne = generateRandomBase64FromByteArray(81);
profiles.set(uuid, profileTwo); final String nameTwo = generateRandomBase64FromByteArray(81);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); final String avatarOne = "profiles/" + generateRandomBase64FromByteArray(16);
final String avatarTwo = "profiles/" + generateRandomBase64FromByteArray(16);
final String aboutEmoji = generateRandomBase64FromByteArray(60);
final String about = generateRandomBase64FromByteArray(156);
final byte[] commitmentOne = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
final byte[] commitmentTwo = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null,
null, commitmentOne);
VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, commitmentTwo);
profiles.set(ACI, profileOne);
profiles.set(ACI, profileTwo);
Optional<VersionedProfile> retrieved = profiles.get(ACI, versionOne);
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profileOne.getName()); assertThat(retrieved.get().getName()).isEqualTo(profileOne.getName());
@ -170,7 +199,7 @@ public class ProfilesTest {
assertThat(retrieved.get().getAbout()).isEqualTo(profileOne.getAbout()); assertThat(retrieved.get().getAbout()).isEqualTo(profileOne.getAbout());
assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profileOne.getAboutEmoji()); assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profileOne.getAboutEmoji());
retrieved = profiles.get(uuid, "345"); retrieved = profiles.get(ACI, versionTwo);
assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.isPresent()).isTrue();
assertThat(retrieved.get().getName()).isEqualTo(profileTwo.getName()); assertThat(retrieved.get().getName()).isEqualTo(profileTwo.getName());
@ -182,33 +211,45 @@ public class ProfilesTest {
@Test @Test
void testMissing() { void testMissing() {
UUID uuid = UUID.randomUUID(); profiles.set(ACI, validProfile);
VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", null, null, final String missingVersion = "missingVersion";
null, "aDigest".getBytes());
profiles.set(uuid, profile);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "888"); Optional<VersionedProfile> retrieved = profiles.get(ACI, missingVersion);
assertThat(retrieved.isPresent()).isFalse(); assertThat(retrieved.isPresent()).isFalse();
} }
@Test @Test
void testDelete() { void testDelete() throws InvalidInputException {
UUID uuid = UUID.randomUUID(); final String versionOne = "versionOne";
VersionedProfile profileOne = new VersionedProfile("123", "foo", "avatarLocation", null, null, final String versionTwo = "versionTwo";
null, "aDigest".getBytes());
VersionedProfile profileTwo = new VersionedProfile("345", "bar", "baz", null, null, null, "boof".getBytes());
profiles.set(uuid, profileOne); final String nameOne = generateRandomBase64FromByteArray(81);
profiles.set(uuid, profileTwo); final String nameTwo = generateRandomBase64FromByteArray(81);
profiles.deleteAll(uuid); final String aboutEmoji = generateRandomBase64FromByteArray(60);
final String about = generateRandomBase64FromByteArray(156);
Optional<VersionedProfile> retrieved = profiles.get(uuid, "123"); final String avatarOne = "profiles/" + generateRandomBase64FromByteArray(16);
final String avatarTwo = "profiles/" + generateRandomBase64FromByteArray(16);
final byte[] commitmentOne = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
final byte[] commitmentTwo = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null,
null, commitmentOne);
VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, commitmentTwo);
profiles.set(ACI, profileOne);
profiles.set(ACI, profileTwo);
profiles.deleteAll(ACI);
Optional<VersionedProfile> retrieved = profiles.get(ACI, versionOne);
assertThat(retrieved.isPresent()).isFalse(); assertThat(retrieved.isPresent()).isFalse();
retrieved = profiles.get(uuid, "345"); retrieved = profiles.get(ACI, versionTwo);
assertThat(retrieved.isPresent()).isFalse(); assertThat(retrieved.isPresent()).isFalse();
} }
@ -219,32 +260,38 @@ public class ProfilesTest {
assertEquals(expectedUpdateExpression, Profiles.buildUpdateExpression(profile)); assertEquals(expectedUpdateExpression, Profiles.buildUpdateExpression(profile));
} }
private static Stream<Arguments> buildUpdateExpression() { private static Stream<Arguments> buildUpdateExpression() throws InvalidInputException {
final byte[] commitment = "commitment".getBytes(StandardCharsets.UTF_8); final String version = "someVersion";
final String name = generateRandomBase64FromByteArray(81);
final String avatar = "profiles/" + generateRandomBase64FromByteArray(16);;
final String emoji = generateRandomBase64FromByteArray(60);
final String about = generateRandomBase64FromByteArray(156);
final String paymentAddress = generateRandomBase64FromByteArray(582);
final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
return Stream.of( return Stream.of(
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", "about", "paymentAddress", commitment), new VersionedProfile(version, name, avatar, emoji, about, paymentAddress, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji, #paymentAddress = :paymentAddress"), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji, #paymentAddress = :paymentAddress"),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", "about", null, commitment), new VersionedProfile(version, name, avatar, emoji, about, null, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji REMOVE #paymentAddress"), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji REMOVE #paymentAddress"),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", null, null, commitment), new VersionedProfile(version, name, avatar, emoji, null, null, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #aboutEmoji = :aboutEmoji REMOVE #about, #paymentAddress"), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #aboutEmoji = :aboutEmoji REMOVE #about, #paymentAddress"),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", null, null, null, commitment), new VersionedProfile(version, name, avatar, null, null, null, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar REMOVE #about, #aboutEmoji, #paymentAddress"), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar REMOVE #about, #aboutEmoji, #paymentAddress"),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", null, null, null, null, commitment), new VersionedProfile(version, name, null, null, null, null, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment), #name = :name REMOVE #avatar, #about, #aboutEmoji, #paymentAddress"), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name REMOVE #avatar, #about, #aboutEmoji, #paymentAddress"),
Arguments.of( Arguments.of(
new VersionedProfile("version", null, null, null, null, null, commitment), new VersionedProfile(version, null, null, null, null, null, commitment),
"SET #commitment = if_not_exists(#commitment, :commitment) REMOVE #name, #avatar, #about, #aboutEmoji, #paymentAddress") "SET #commitment = if_not_exists(#commitment, :commitment) REMOVE #name, #avatar, #about, #aboutEmoji, #paymentAddress")
); );
} }
@ -255,53 +302,69 @@ public class ProfilesTest {
assertEquals(expectedAttributeValues, Profiles.buildUpdateExpressionAttributeValues(profile)); assertEquals(expectedAttributeValues, Profiles.buildUpdateExpressionAttributeValues(profile));
} }
private static Stream<Arguments> buildUpdateExpressionAttributeValues() { private static Stream<Arguments> buildUpdateExpressionAttributeValues() throws InvalidInputException {
final byte[] commitment = "commitment".getBytes(StandardCharsets.UTF_8); final String version = "someVersion";
final String name = generateRandomBase64FromByteArray(81);
final String avatar = "profiles/" + generateRandomBase64FromByteArray(16);;
final String emoji = generateRandomBase64FromByteArray(60);
final String about = generateRandomBase64FromByteArray(156);
final String paymentAddress = generateRandomBase64FromByteArray(582);
final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
return Stream.of( return Stream.of(
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", "about", "paymentAddress", commitment), new VersionedProfile(version, name, avatar, emoji, about, paymentAddress, commitment),
Map.of( Map.of(
":commitment", AttributeValues.fromByteArray(commitment), ":commitment", AttributeValues.fromByteArray(commitment),
":name", AttributeValues.fromString("name"), ":name", AttributeValues.fromString(name),
":avatar", AttributeValues.fromString("avatar"), ":avatar", AttributeValues.fromString(avatar),
":aboutEmoji", AttributeValues.fromString("emoji"), ":aboutEmoji", AttributeValues.fromString(emoji),
":about", AttributeValues.fromString("about"), ":about", AttributeValues.fromString(about),
":paymentAddress", AttributeValues.fromString("paymentAddress"))), ":paymentAddress", AttributeValues.fromString(paymentAddress))),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", "about", null, commitment), new VersionedProfile(version, name, avatar, emoji, about, null, commitment),
Map.of( Map.of(
":commitment", AttributeValues.fromByteArray(commitment), ":commitment", AttributeValues.fromByteArray(commitment),
":name", AttributeValues.fromString("name"), ":name", AttributeValues.fromString(name),
":avatar", AttributeValues.fromString("avatar"), ":avatar", AttributeValues.fromString(avatar),
":aboutEmoji", AttributeValues.fromString("emoji"), ":aboutEmoji", AttributeValues.fromString(emoji),
":about", AttributeValues.fromString("about"))), ":about", AttributeValues.fromString(about))),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", "emoji", null, null, commitment), new VersionedProfile(version, name, avatar, emoji, null, null, commitment),
Map.of( Map.of(
":commitment", AttributeValues.fromByteArray(commitment), ":commitment", AttributeValues.fromByteArray(commitment),
":name", AttributeValues.fromString("name"), ":name", AttributeValues.fromString(name),
":avatar", AttributeValues.fromString("avatar"), ":avatar", AttributeValues.fromString(avatar),
":aboutEmoji", AttributeValues.fromString("emoji"))), ":aboutEmoji", AttributeValues.fromString(emoji))),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", "avatar", null, null, null, commitment), new VersionedProfile(version, name, avatar, null, null, null, commitment),
Map.of( Map.of(
":commitment", AttributeValues.fromByteArray(commitment), ":commitment", AttributeValues.fromByteArray(commitment),
":name", AttributeValues.fromString("name"), ":name", AttributeValues.fromString(name),
":avatar", AttributeValues.fromString("avatar"))), ":avatar", AttributeValues.fromString(avatar))),
Arguments.of( Arguments.of(
new VersionedProfile("version", "name", null, null, null, null, commitment), new VersionedProfile(version, name, null, null, null, null, commitment),
Map.of( Map.of(
":commitment", AttributeValues.fromByteArray(commitment), ":commitment", AttributeValues.fromByteArray(commitment),
":name", AttributeValues.fromString("name"))), ":name", AttributeValues.fromString(name))),
Arguments.of( Arguments.of(
new VersionedProfile("version", null, null, null, null, null, commitment), new VersionedProfile(version, null, null, null, null, null, commitment),
Map.of(":commitment", AttributeValues.fromByteArray(commitment))) Map.of(":commitment", AttributeValues.fromByteArray(commitment)))
); );
} }
private static String generateRandomBase64FromByteArray(final int byteArrayLength) {
return Base64.getEncoder().encodeToString(generateRandomByteArray(byteArrayLength));
}
private static byte[] generateRandomByteArray(final int length) {
byte[] byteArray = new byte[length];
new Random().nextBytes(byteArray);
return byteArray;
}
} }

View File

@ -122,34 +122,4 @@ abstract class SingleUsePreKeyStoreTest<K extends PreKey<?>> {
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join());
assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join());
} }
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> getPreKeyStore().extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
} }

View File

@ -6,10 +6,16 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -63,4 +69,34 @@ public class AttributeValuesTest {
final Map<String, AttributeValue> item = Map.of("key", AttributeValue.builder().nul(true).build()); final Map<String, AttributeValue> item = Map.of("key", AttributeValue.builder().nul(true).build());
assertNull(AttributeValues.getUUID(item, "key", null)); assertNull(AttributeValues.getUUID(item, "key", null));
} }
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, AttributeValues.extractByteArray(attributeValue, "counter"));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> AttributeValues.extractByteArray(attributeValue, "counter"));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
} }