From d3e0ba6d44d36611b9e7f7aeffefa28757af3e7e Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 17 May 2023 17:25:18 -0400 Subject: [PATCH] Prepare to read pre-keys stored as byte arrays --- .../textsecuregcm/storage/Keys.java | 28 +++- .../textsecuregcm/storage/KeysTest.java | 121 +++++++++++++----- 2 files changed, 113 insertions(+), 36 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index 4c92c0377..9676aa5b8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -15,6 +15,7 @@ import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Counter; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; @@ -22,6 +23,8 @@ import java.util.UUID; import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.Nullable; + +import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; @@ -380,11 +383,32 @@ public class Keys extends AbstractDynamoDbStore { private PreKey getPreKeyFromItem(Map item) { final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + final String publicKey = Base64.getEncoder().encodeToString(extractByteArray(item.get(KEY_PUBLIC_KEY))); + if (item.containsKey(KEY_SIGNATURE)) { // All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored // in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys. - return new SignedPreKey(keyId, item.get(KEY_PUBLIC_KEY).s(), item.get(KEY_SIGNATURE).s()); + final String signature = Base64.getEncoder().encodeToString(extractByteArray(item.get(KEY_SIGNATURE))); + return new SignedPreKey(keyId, publicKey, signature); } - return new PreKey(keyId, item.get(KEY_PUBLIC_KEY).s()); + return new PreKey(keyId, publicKey); + } + + /** + * Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string. + * + * @param attributeValue the {@code AttributeValue} from which to extract a byte array + * + * @return the byte array represented by the given {@code AttributeValue} + */ + @VisibleForTesting + static byte[] extractByteArray(final AttributeValue attributeValue) { + if (attributeValue.b() != null) { + return attributeValue.b().asByteArray(); + } else if (StringUtils.isNotBlank(attributeValue.s())) { + return Base64.getDecoder().decode(attributeValue.s()); + } + + throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value"); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java index cb6b4f8a0..c6f84d179 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java @@ -5,29 +5,33 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertIterableEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - +import java.security.SecureRandom; +import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.stream.Stream; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.util.AttributeValues; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.QueryResponse; import software.amazon.awssdk.services.dynamodb.model.Select; +import static org.junit.jupiter.api.Assertions.*; + class KeysTest { private Keys keys; @@ -57,41 +61,41 @@ class KeysTest { assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), "Initial last-resort pre-key for an account should be missing"); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Repeatedly storing same key should have no effect"); - keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(new SignedPreKey(1, "pq-public-key", "sig")), null); + keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ prekeys should have no effect on EC prekeys"); assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, new SignedPreKey(1001, "pq-last-resort-key", "sig")); + keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001)); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key")), null, null); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Uploading new EC prekeys should have no effect on PQ prekeys"); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key")), List.of(new SignedPreKey(2, "different-pq-public-key", "sig")), null); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(new PreKey(4, "fourth-public-key"), new PreKey(5, "fifth-public-key")), - List.of(new SignedPreKey(6, "sixth-pq-key", "sig"), new SignedPreKey(7, "seventh-pq-key", "sig")), - new SignedPreKey(1002, "new-last-resort-key", "sig")); + List.of(generateTestPreKey(4), generateTestPreKey(5)), + List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)), + generateTestSignedPreKey(1002)); assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), "Inserting multiple new keys should overwrite all prior keys for the given account/device"); assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), @@ -104,10 +108,11 @@ class KeysTest { void testTakeAccountAndDeviceId() { assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); - final PreKey preKey = new PreKey(1, "public-key"); + final PreKey preKey = generateTestPreKey(1); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); - assertEquals(Optional.of(preKey), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); + final Optional takenKey = keys.takeEC(ACCOUNT_UUID, DEVICE_ID); + assertEquals(Optional.of(preKey), takenKey); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); } @@ -115,9 +120,9 @@ class KeysTest { void testTakePQ() { assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); - final SignedPreKey preKey1 = new SignedPreKey(1, "public-key", "sig"); - final SignedPreKey preKey2 = new SignedPreKey(2, "different-public-key", "sig"); - final SignedPreKey preKeyLast = new SignedPreKey(1001, "last-public-key", "sig"); + final SignedPreKey preKey1 = generateTestSignedPreKey(1); + final SignedPreKey preKey2 = generateTestSignedPreKey(2); + final SignedPreKey preKeyLast = generateTestSignedPreKey(1001); keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); @@ -139,7 +144,7 @@ class KeysTest { assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key")), List.of(new SignedPreKey(1, "public-pq-key", "sig")), null); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null); assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); } @@ -147,14 +152,14 @@ class KeysTest { @Test void testDeleteByAccount() { keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")), - List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")), - new SignedPreKey(5, "last-pq-key", "sig")); + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), + generateTestSignedPreKey(5)); keys.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(new PreKey(6, "public-key-for-different-device")), - List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")), - new SignedPreKey(8, "last-pq-key-for-different-device", "sig")); + List.of(generateTestPreKey(6)), + List.of(generateTestSignedPreKey(7)), + generateTestSignedPreKey(8)); assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); @@ -176,14 +181,14 @@ class KeysTest { @Test void testDeleteByAccountAndDevice() { keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key")), - List.of(new SignedPreKey(3, "public-pq-key", "sig"), new SignedPreKey(4, "different-pq-key", "sig")), - new SignedPreKey(5, "last-pq-key", "sig")); + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), + generateTestSignedPreKey(5)); keys.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(new PreKey(6, "public-key-for-different-device")), - List.of(new SignedPreKey(7, "public-pq-key-for-different-device", "sig")), - new SignedPreKey(8, "last-pq-key-for-different-device", "sig")); + List.of(generateTestPreKey(6)), + List.of(generateTestSignedPreKey(7)), + generateTestSignedPreKey(8)); assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); @@ -251,4 +256,52 @@ class KeysTest { AttributeValue got = Keys.getSortKeyPrefix(123); assertArrayEquals(new byte[]{0, 0, 0, 0, 0, 0, 0, 123}, got.b().asByteArray()); } + + @ParameterizedTest + @MethodSource + void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) { + assertArrayEquals(expectedByteArray, Keys.extractByteArray(attributeValue)); + } + + private static Stream extractByteArray() { + final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc="); + + return Stream.of( + Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key) + ); + } + + @ParameterizedTest + @MethodSource + void extractByteArrayIllegalArgument(final AttributeValue attributeValue) { + assertThrows(IllegalArgumentException.class, () -> Keys.extractByteArray(attributeValue)); + } + + private static Stream extractByteArrayIllegalArgument() { + return Stream.of( + Arguments.of(AttributeValue.fromN("12")), + Arguments.of(AttributeValue.fromS("")), + Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎")) + ); + } + + private static PreKey generateTestPreKey(final long keyId) { + final byte[] key = new byte[32]; + new SecureRandom().nextBytes(key); + + return new PreKey(keyId, Base64.getEncoder().encodeToString(key)); + } + + private static SignedPreKey generateTestSignedPreKey(final long keyId) { + final byte[] key = new byte[32]; + final byte[] signature = new byte[32]; + + final SecureRandom secureRandom = new SecureRandom(); + secureRandom.nextBytes(key); + secureRandom.nextBytes(signature); + + return new SignedPreKey(keyId, Base64.getEncoder().encodeToString(key), Base64.getEncoder().encodeToString(signature)); + } }