Prepare to read pre-keys stored as byte arrays

This commit is contained in:
Jon Chambers 2023-05-17 17:25:18 -04:00 committed by Jon Chambers
parent 300ac16cf1
commit d3e0ba6d44
2 changed files with 113 additions and 36 deletions

View File

@ -15,6 +15,7 @@ import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.Counter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -22,6 +23,8 @@ import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.AttributeValues;
@ -380,11 +383,32 @@ public class Keys extends AbstractDynamoDbStore {
private PreKey getPreKeyFromItem(Map<String, AttributeValue> item) {
final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8);
final String publicKey = Base64.getEncoder().encodeToString(extractByteArray(item.get(KEY_PUBLIC_KEY)));
if (item.containsKey(KEY_SIGNATURE)) {
// All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored
// in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys.
return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s());
final String signature = Base64.getEncoder().encodeToString(extractByteArray(item.get(KEY_SIGNATURE)));
return new SignedPreKey(keyId, publicKey, signature);
}
return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s());
return new PreKey(keyId, publicKey);
}
/**
* Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string.
*
* @param attributeValue the {@code AttributeValue} from which to extract a byte array
*
* @return the byte array represented by the given {@code AttributeValue}
*/
@VisibleForTesting
static byte[] extractByteArray(final AttributeValue attributeValue) {
if (attributeValue.b() != null) {
return attributeValue.b().asByteArray();
} else if (StringUtils.isNotBlank(attributeValue.s())) {
return Base64.getDecoder().decode(attributeValue.s());
}
throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value");
}
}

View File

@ -5,29 +5,33 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.Select;
import static org.junit.jupiter.api.Assertions.*;
class KeysTest {
private Keys keys;
@ -57,41 +61,41 @@ class KeysTest {
assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(),
"Initial last-resort pre-key for an account should be missing");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Repeatedly storing same key should have no effect");
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1, "pq-public-key", "sig")), null);
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ prekeys should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, new SignedPreKey(1001, "pq-last-resort-key", "sig"));
keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001));
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on EC prekeys");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys");
assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId());
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")), null, null);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Uploading new EC prekeys should have no effect on PQ prekeys");
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key")), List.of(new SignedPreKey(2, "different-pq-public-key", "sig")), null);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting a new key should overwrite all prior keys of the same type for the given account/device");
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(4, "fourth-public-key"), new PreKey(5, "fifth-public-key")),
List.of(new SignedPreKey(6, "sixth-pq-key", "sig"), new SignedPreKey(7, "seventh-pq-key", "sig")),
new SignedPreKey(1002, "new-last-resort-key", "sig"));
List.of(generateTestPreKey(4), generateTestPreKey(5)),
List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)),
generateTestSignedPreKey(1002));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID),
"Inserting multiple new keys should overwrite all prior keys for the given account/device");
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID),
@ -104,10 +108,11 @@ class KeysTest {
void testTakeAccountAndDeviceId() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final PreKey preKey = new PreKey(1, "public-key");
final PreKey preKey = generateTestPreKey(1);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(Optional.of(preKey), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)));
final Optional<PreKey> takenKey = keys.takeEC(ACCOUNT_UUID, DEVICE_ID);
assertEquals(Optional.of(preKey), takenKey);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
}
@ -115,9 +120,9 @@ class KeysTest {
void testTakePQ() {
assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID));
final SignedPreKey preKey1 = new SignedPreKey(1, "public-key", "sig");
final SignedPreKey preKey2 = new SignedPreKey(2, "different-public-key", "sig");
final SignedPreKey preKeyLast = new SignedPreKey(1001, "last-public-key", "sig");
final SignedPreKey preKey1 = generateTestSignedPreKey(1);
final SignedPreKey preKey2 = generateTestSignedPreKey(2);
final SignedPreKey preKeyLast = generateTestSignedPreKey(1001);
keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast);
@ -139,7 +144,7 @@ class KeysTest {
assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")), List.of(new SignedPreKey(1, "public-pq-key", "sig")), null);
keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null);
assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
}
@ -147,14 +152,14 @@ class KeysTest {
@Test
void testDeleteByAccount() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@ -176,14 +181,14 @@ class KeysTest {
@Test
void testDeleteByAccountAndDevice() {
keys.store(ACCOUNT_UUID, DEVICE_ID,
List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")),
List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")),
new SignedPreKey(5, "last-pq-key", "sig"));
List.of(generateTestPreKey(1), generateTestPreKey(2)),
List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)),
generateTestSignedPreKey(5));
keys.store(ACCOUNT_UUID, DEVICE_ID + 1,
List.of(new PreKey(6, "public-key-for-different-device")),
List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")),
new SignedPreKey(8, "last-pq-key-for-different-device", "sig"));
List.of(generateTestPreKey(6)),
List.of(generateTestSignedPreKey(7)),
generateTestSignedPreKey(8));
assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID));
assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID));
@ -251,4 +256,52 @@ class KeysTest {
AttributeValue got = Keys.getSortKeyPrefix(123);
assertArrayEquals(new byte[]{0, 0, 0, 0, 0, 0, 0, 123}, got.b().asByteArray());
}
@ParameterizedTest
@MethodSource
void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) {
assertArrayEquals(expectedByteArray, Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArray() {
final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc=");
return Stream.of(
Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key),
Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key)
);
}
@ParameterizedTest
@MethodSource
void extractByteArrayIllegalArgument(final AttributeValue attributeValue) {
assertThrows(IllegalArgumentException.class, () -> Keys.extractByteArray(attributeValue));
}
private static Stream<Arguments> extractByteArrayIllegalArgument() {
return Stream.of(
Arguments.of(AttributeValue.fromN("12")),
Arguments.of(AttributeValue.fromS("")),
Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎"))
);
}
private static PreKey generateTestPreKey(final long keyId) {
final byte[] key = new byte[32];
new SecureRandom().nextBytes(key);
return new PreKey(keyId, Base64.getEncoder().encodeToString(key));
}
private static SignedPreKey generateTestSignedPreKey(final long keyId) {
final byte[] key = new byte[32];
final byte[] signature = new byte[32];
final SecureRandom secureRandom = new SecureRandom();
secureRandom.nextBytes(key);
secureRandom.nextBytes(signature);
return new SignedPreKey(keyId, Base64.getEncoder().encodeToString(key), Base64.getEncoder().encodeToString(signature));
}
}