diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiers.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiers.java index 5eadc1c8b..7e5cdc3bc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiers.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiers.java @@ -10,14 +10,27 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.Util; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; -import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; -import software.amazon.awssdk.services.dynamodb.model.ReturnValue; -import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.model.BatchGetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.CancellationReason; +import software.amazon.awssdk.services.dynamodb.model.KeysAndAttributes; +import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; +import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; +import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException; +import software.amazon.awssdk.services.dynamodb.model.Update; /** * Manages a global, persistent mapping of phone numbers to phone number identifiers regardless of whether those @@ -35,8 +48,13 @@ public class PhoneNumberIdentifiers { @VisibleForTesting static final String ATTR_PHONE_NUMBER_IDENTIFIER = "PNI"; + private static final String CONDITIONAL_CHECK_FAILED = "ConditionalCheckFailed"; + private static final Timer GET_PNI_TIMER = Metrics.timer(name(PhoneNumberIdentifiers.class, "get")); private static final Timer SET_PNI_TIMER = Metrics.timer(name(PhoneNumberIdentifiers.class, "set")); + private static final int MAX_RETRIES = 10; + + private static final Logger logger = LoggerFactory.getLogger(PhoneNumberIdentifiers.class); public PhoneNumberIdentifiers(final DynamoDbAsyncClient dynamoDbClient, final String tableName) { this.dynamoDbClient = dynamoDbClient; @@ -44,38 +62,159 @@ public class PhoneNumberIdentifiers { } /** - * Returns the phone number identifier (PNI) associated with the given phone number. + * Returns the phone number identifier (PNI) associated with the given phone number. If one doesn't exist, it is + * created. * * @param phoneNumber the phone number for which to retrieve a phone number identifier * @return the phone number identifier associated with the given phone number */ public CompletableFuture getPhoneNumberIdentifier(final String phoneNumber) { - final Timer.Sample sample = Timer.start(); + // Each e164 phone number string represents a potential equivalence class e164s that represent the same number. If + // this is a new phone number, we'll want to set all the numbers in the equivalence class to the same PNI + final List allPhoneNumberForms = Util.getAlternateForms(phoneNumber); - return dynamoDbClient.getItem(GetItemRequest.builder() - .tableName(tableName) - .key(Map.of(KEY_E164, AttributeValues.fromString(phoneNumber))) - .projectionExpression(ATTR_PHONE_NUMBER_IDENTIFIER) - .build()) - .thenCompose(response -> response.hasItem() - ? CompletableFuture.completedFuture(AttributeValues.getUUID(response.item(), ATTR_PHONE_NUMBER_IDENTIFIER, null)) - : generatePhoneNumberIdentifierIfNotExists(phoneNumber)) - .whenComplete((ignored, throwable) -> sample.stop(GET_PNI_TIMER)); + return retry(MAX_RETRIES, TransactionConflictException.class, () -> fetchPhoneNumbers(allPhoneNumberForms) + .thenCompose(mappings -> setPniIfRequired(phoneNumber, allPhoneNumberForms, mappings))); } @VisibleForTesting - CompletableFuture generatePhoneNumberIdentifierIfNotExists(final String phoneNumber) { - final Timer.Sample sample = Timer.start(); + static CompletableFuture retry( + final int numRetries, final Class exceptionToRetry, final Supplier> supplier) { + return supplier.get().exceptionallyCompose(ExceptionUtils.exceptionallyHandler(exceptionToRetry, e -> { + if (numRetries - 1 <= 0) { + throw ExceptionUtils.wrap(e); + } + return retry(numRetries - 1, exceptionToRetry, supplier); + })); + } - return dynamoDbClient.updateItem(UpdateItemRequest.builder() - .tableName(tableName) - .key(Map.of(KEY_E164, AttributeValues.fromString(phoneNumber))) - .updateExpression("SET #pni = if_not_exists(#pni, :pni)") - .expressionAttributeNames(Map.of("#pni", ATTR_PHONE_NUMBER_IDENTIFIER)) - .expressionAttributeValues(Map.of(":pni", AttributeValues.fromUUID(UUID.randomUUID()))) - .returnValues(ReturnValue.ALL_NEW) + /** + * Determine what PNI to set for the provided numbers, and set them if required + * + * @param phoneNumber The original e164 the operation is for + * @param allPhoneNumberForms The e164s to set. The first e164 in this list should be phoneNumber + * @param existingAssociations The current associations of allPhoneNumberForms in the table + * @return The PNI now associated with phoneNumber + */ + @VisibleForTesting + CompletableFuture setPniIfRequired( + final String phoneNumber, + final List allPhoneNumberForms, + Map existingAssociations) { + if (!phoneNumber.equals(allPhoneNumberForms.getFirst())) { + throw new IllegalArgumentException("allPhoneNumberForms must start with the target phoneNumber"); + } + + if (existingAssociations.containsKey(phoneNumber)) { + // If the provided phone number already has an association, just return that + return CompletableFuture.completedFuture(existingAssociations.get(phoneNumber)); + } + + if (allPhoneNumberForms.size() == 1 || existingAssociations.isEmpty()) { + // Easy case, if we're the only phone number in our equivalence class or there are no existing associations, + // we can just make an association for a new PNI + return setPni(phoneNumber, allPhoneNumberForms, UUID.randomUUID()); + } + + // Otherwise, what members of the equivalence class have a PNI association? + final Map> byPni = existingAssociations.entrySet().stream().collect(Collectors.groupingBy( + Map.Entry::getValue, + Collectors.mapping(Map.Entry::getKey, Collectors.toList()))); + + // Usually there should be only a single PNI associated with the equivalence class, but it's possible there's + // more. This could only happen if an equivalence class had more than two numbers, and had accumulated 2 unique + // PNI associations before they were merged into a single class. In this case we've picked one of those pnis + // arbitrarily (according to their ordering as returned by getAlternateForms) + final UUID existingPni = allPhoneNumberForms.stream() + .filter(existingAssociations::containsKey) + .findFirst() + .map(existingAssociations::get) + .orElseThrow(() -> new IllegalStateException("Previously checked that a mapping existed")); + + if (byPni.size() > 1) { + logger.warn("More than one PNI existed in the PNI table for the numbers that map to {}. " + + "Arbitrarily picking {} to be the representative PNI for the numbers without PNI associations", + phoneNumber, existingPni); + } + + // Find all the unmapped phoneNumbers and set them to the PNI we chose from another member of the equivalence class + final List unmappedNumbers = allPhoneNumberForms.stream() + .filter(number -> !existingAssociations.containsKey(number)) + .toList(); + + return setPni(phoneNumber, unmappedNumbers, existingPni); + } + + + /** + * Attempt to associate phoneNumbers with the provided pni. If any of the phoneNumbers have an existing association + * that is not the target pni, no update will occur. If the first phoneNumber in phoneNumbers has an existing + * association, it will be returned, otherwise an exception will be thrown. + * + * @param originalPhoneNumber The original e164 the operation is for + * @param allPhoneNumberForms The e164s to set. The first e164 in this list should be originalPhoneNumber + * @param pni The PNI to set + * @return The provided PNI if the update occurred, or the existing PNI associated with originalPhoneNumber + */ + @VisibleForTesting + CompletableFuture setPni(final String originalPhoneNumber, final List allPhoneNumberForms, + final UUID pni) { + if (!originalPhoneNumber.equals(allPhoneNumberForms.getFirst())) { + throw new IllegalArgumentException("allPhoneNumberForms must start with the target phoneNumber"); + } + + final Timer.Sample sample = Timer.start(); + final List transactWriteItems = allPhoneNumberForms + .stream() + .map(phoneNumber -> TransactWriteItem.builder() + .update(Update.builder() + .tableName(tableName) + .key(Map.of(KEY_E164, AttributeValues.fromString(phoneNumber))) + .updateExpression("SET #pni = :pni") + // It's possible we're racing with someone else to update, but both of us selected the same PNI because + // an equivalent number already had it. That's fine, as long as the association happens. + .conditionExpression("attribute_not_exists(#pni) OR #pni = :pni") + .expressionAttributeNames(Map.of("#pni", ATTR_PHONE_NUMBER_IDENTIFIER)) + .expressionAttributeValues(Map.of(":pni", AttributeValues.fromUUID(pni))) + .returnValuesOnConditionCheckFailure(ReturnValuesOnConditionCheckFailure.ALL_OLD) + .build()).build()) + .toList(); + + return dynamoDbClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems(transactWriteItems) .build()) - .thenApply(response -> AttributeValues.getUUID(response.attributes(), ATTR_PHONE_NUMBER_IDENTIFIER, null)) + .thenApply(ignored -> pni) + .exceptionally(ExceptionUtils.exceptionallyHandler(TransactionCanceledException.class, e -> { + if (e.hasCancellationReasons()) { + // Get the cancellation reason for the number that we were primarily trying to associate with a PNI + final CancellationReason cancelReason = e.cancellationReasons().getFirst(); + if (CONDITIONAL_CHECK_FAILED.equals(cancelReason.code())) { + // Someone else beat us to the update, use the PNI they set. + return AttributeValues.getUUID(cancelReason.item(), ATTR_PHONE_NUMBER_IDENTIFIER, null); + } + } + throw e; + })) .whenComplete((ignored, throwable) -> sample.stop(SET_PNI_TIMER)); } + + @VisibleForTesting + CompletableFuture> fetchPhoneNumbers(List phoneNumbers) { + final Timer.Sample sample = Timer.start(); + return dynamoDbClient.batchGetItem( + BatchGetItemRequest.builder().requestItems(Map.of(tableName, KeysAndAttributes.builder() + // If we have a stale value, the subsequent conditional update will fail + .consistentRead(false) + .projectionExpression("#number,#pni") + .expressionAttributeNames(Map.of("#number", KEY_E164, "#pni", ATTR_PHONE_NUMBER_IDENTIFIER)) + .keys(phoneNumbers.stream() + .map(number -> Map.of(KEY_E164, AttributeValues.fromString(number))) + .toArray(Map[]::new)) + .build())) + .build()) + .thenApply(batchResponse -> batchResponse.responses().get(tableName).stream().collect(Collectors.toMap( + item -> AttributeValues.getString(item, KEY_E164, null), + item -> AttributeValues.getUUID(item, ATTR_PHONE_NUMBER_IDENTIFIER, null)))) + .whenComplete((ignored, throwable) -> sample.stop(GET_PNI_TIMER)); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java index c9ea4d80b..d17a3d880 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/Util.java @@ -114,7 +114,8 @@ public class Util { * * @param number the e164-formatted phone number for which to find equivalent forms * - * @return a list of phone numbers equivalent to the given phone number, including the given number + * @return a list of phone numbers equivalent to the given phone number, including the given number. The given number + * will always be the first element of the list. */ public static List getAlternateForms(final String number) { try { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiersTest.java index 814cdf6e1..cfb189867 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiersTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PhoneNumberIdentifiersTest.java @@ -6,15 +6,23 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; -import java.util.Optional; +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; +import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; class PhoneNumberIdentifiersTest { @@ -42,10 +50,143 @@ class PhoneNumberIdentifiersTest { } @Test - void generatePhoneNumberIdentifierIfNotExists() { - final String number = "+18005551234"; + void generatePhoneNumberIdentifier() { + final List numbers = List.of("+18005551234", "+18005556789"); + // Should set both PNIs to a new random PNI + final UUID pni = phoneNumberIdentifiers.setPniIfRequired(numbers.getFirst(), numbers, Collections.emptyMap()).join(); - assertEquals(phoneNumberIdentifiers.generatePhoneNumberIdentifierIfNotExists(number).join(), - phoneNumberIdentifiers.generatePhoneNumberIdentifierIfNotExists(number).join()); + assertEquals(pni, phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.getFirst()).join()); + assertEquals(pni, phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.getLast()).join()); + } + + @Test + void generatePhoneNumberIdentifierOneFormExists() { + final String firstNumber = "+18005551234"; + final String secondNumber = "+18005556789"; + final String thirdNumber = "+1800555456"; + final List allNumbers = List.of(firstNumber, secondNumber, thirdNumber); + + // Set one member of the "same" numbers to a new PNI + final UUID pni = phoneNumberIdentifiers.getPhoneNumberIdentifier(secondNumber).join(); + + final Map existingAssociations = phoneNumberIdentifiers.fetchPhoneNumbers(allNumbers).join(); + assertEquals(Map.of(secondNumber, pni), existingAssociations); + + assertEquals(pni, phoneNumberIdentifiers.setPniIfRequired(firstNumber, allNumbers, existingAssociations).join()); + + for (String number : allNumbers) { + assertEquals(pni, phoneNumberIdentifiers.getPhoneNumberIdentifier(number).join()); + } + } + + @Test + void getPhoneNumberIdentifierExistingMapping() { + final String newFormatBeninE164 = PhoneNumberUtil.getInstance() + .format(PhoneNumberUtil.getInstance().getExampleNumber("BJ"), PhoneNumberUtil.PhoneNumberFormat.E164); + + final String oldFormatBeninE164 = newFormatBeninE164.replaceFirst("01", ""); + final UUID oldFormatPni = phoneNumberIdentifiers.getPhoneNumberIdentifier(oldFormatBeninE164).join(); + final UUID newFormatPni = phoneNumberIdentifiers.getPhoneNumberIdentifier(newFormatBeninE164).join(); + assertEquals(oldFormatPni, newFormatPni); + } + + @Test + void conflictingExistingPnis() { + final String firstNumber = "+18005551234"; + final String secondNumber = "+18005556789"; + + final UUID firstPni = phoneNumberIdentifiers.getPhoneNumberIdentifier(firstNumber).join(); + final UUID secondPni = phoneNumberIdentifiers.getPhoneNumberIdentifier(secondNumber).join(); + assertNotEquals(firstPni, secondPni); + + assertEquals( + firstPni, + phoneNumberIdentifiers.setPniIfRequired( + firstNumber, List.of(firstNumber, secondNumber), + phoneNumberIdentifiers.fetchPhoneNumbers(List.of(firstNumber, secondNumber)).join()).join()); + assertEquals( + secondPni, + phoneNumberIdentifiers.setPniIfRequired( + secondNumber, List.of(secondNumber, firstNumber), + phoneNumberIdentifiers.fetchPhoneNumbers(List.of(firstNumber, secondNumber)).join()).join()); + } + + @Test + void conflictOnOriginalNumber() { + final List numbers = List.of("+18005551234", "+18005556789"); + // Stale view of database where both numbers have no PNI + final Map existingAssociations = Collections.emptyMap(); + + // Both numbers have different PNIs + final UUID pni1 = phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.getFirst()).join(); + final UUID pni2 = phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.getLast()).join(); + assertNotEquals(pni1, pni2); + + // Should conflict and find that we now have a PNI + assertEquals(pni1, phoneNumberIdentifiers.setPniIfRequired(numbers.getFirst(), numbers, existingAssociations).join()); + } + + @Test + void conflictOnAlternateNumber() { + final List numbers = List.of("+18005551234", "+18005556789"); + // Stale view of database where both numbers have no PNI + final Map existingAssociations = Collections.emptyMap(); + + // the alternate number has a PNI added + phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.getLast()).join(); + + // Should conflict and fail + CompletableFutureTestUtil.assertFailsWithCause( + TransactionCanceledException.class, + phoneNumberIdentifiers.setPniIfRequired(numbers.getFirst(), numbers, existingAssociations)); + } + + @Test + void multipleAssociations() { + final List numbers = List.of("+18005550000", "+18005551111", "+18005552222", "+18005553333", "+1800555444"); + + // Set pni1={number1, number2}, pni2={number3}, number0 and number 4 unset + final UUID pni1 = phoneNumberIdentifiers.setPniIfRequired(numbers.get(1), numbers.subList(1, 3), + Collections.emptyMap()).join(); + final UUID pni2 = phoneNumberIdentifiers.setPniIfRequired(numbers.get(3), List.of(numbers.get(3)), + Collections.emptyMap()).join(); + + final Map existingAssociations = phoneNumberIdentifiers.fetchPhoneNumbers(numbers).join(); + assertEquals(existingAssociations, Map.of(numbers.get(1), pni1, numbers.get(2), pni1, numbers.get(3), pni2)); + + // The unmapped phone numbers should map to the arbitrarily selected PNI (which is selected based on the order + // of the numbers) + assertEquals(pni1, phoneNumberIdentifiers.setPniIfRequired(numbers.get(0), numbers, existingAssociations).join()); + assertEquals(pni1, phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.get(0)).join()); + assertEquals(pni1, phoneNumberIdentifiers.getPhoneNumberIdentifier(numbers.get(4)).join()); + } + + private static class FailN implements Supplier> { + final AtomicInteger numFails; + + FailN(final int numFails) { + this.numFails = new AtomicInteger(numFails); + } + + @Override + public CompletableFuture get() { + if (numFails.getAndDecrement() == 0) { + return CompletableFuture.completedFuture(7); + } + return CompletableFuture.failedFuture(new IOException("test")); + } + } + + @Test + void testRetry() { + assertEquals(7, PhoneNumberIdentifiers.retry(10, IOException.class, new FailN(9)).join()); + + CompletableFutureTestUtil.assertFailsWithCause( + IOException.class, + PhoneNumberIdentifiers.retry(10, IOException.class, new FailN(10))); + + CompletableFutureTestUtil.assertFailsWithCause( + IOException.class, + PhoneNumberIdentifiers.retry(10, RuntimeException.class, new FailN(1))); } }