diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java index 7d1f2a119..daa870b8c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java @@ -7,17 +7,27 @@ package org.whispersystems.textsecuregcm.storage; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore { + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final String tableName; + public RepeatedUseECSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { super(dynamoDbAsyncClient, tableName); + + this.dynamoDbAsyncClient = dynamoDbAsyncClient; + this.tableName = tableName; } @Override @@ -43,4 +53,21 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore throw new IllegalArgumentException(e); } } + + public CompletableFuture storeIfAbsent(final UUID identifier, final long deviceId, final ECSignedPreKey signedPreKey) { + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) + .conditionExpression("attribute_not_exists(#public_key)") + .expressionAttributeNames(Map.of("#public_key", ATTR_PUBLIC_KEY)) + .build()) + .thenApply(ignored -> true) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) { + return false; + } + + throw ExceptionUtils.wrap(throwable); + }); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java index 11c467ddc..1f6476d8c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java @@ -6,11 +6,18 @@ package org.whispersystems.textsecuregcm.storage; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import java.util.Optional; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest { @@ -39,4 +46,21 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes protected ECSignedPreKey generateSignedPreKey() { return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR); } + + @Test + void storeIfAbsent() { + final UUID identifier = UUID.randomUUID(); + final long deviceIdWithExistingKey = 1; + final long deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1; + + final ECSignedPreKey originalSignedPreKey = generateSignedPreKey(); + + keyStore.store(identifier, deviceIdWithExistingKey, originalSignedPreKey).join(); + + assertFalse(keyStore.storeIfAbsent(identifier, deviceIdWithExistingKey, generateSignedPreKey()).join()); + assertTrue(keyStore.storeIfAbsent(identifier, deviceIdWithoutExistingKey, generateSignedPreKey()).join()); + + assertEquals(Optional.of(originalSignedPreKey), keyStore.find(identifier, deviceIdWithExistingKey).join()); + assertTrue(keyStore.find(identifier, deviceIdWithoutExistingKey).join().isPresent()); + } }