From 47cc7fd615fe5d159953190b546e199aac955df9 Mon Sep 17 00:00:00 2001 From: Sergey Skrobotov Date: Fri, 2 Jun 2023 10:15:09 -0700 Subject: [PATCH] username links API --- .../controllers/AccountController.java | 120 ++++++++- .../entities/EncryptedUsername.java | 12 + .../entities/UsernameLinkHandle.java | 12 + .../textsecuregcm/limits/RateLimiters.java | 2 + .../textsecuregcm/storage/Account.java | 44 ++- .../textsecuregcm/storage/Accounts.java | 66 ++++- .../storage/AccountsManager.java | 107 ++++---- .../util/UsernameHashZkProofVerifier.java | 7 +- .../controllers/AccountControllerTest.java | 252 ++++++++++++++---- .../textsecuregcm/storage/AccountsTest.java | 112 +++++++- .../storage/DynamoDbExtensionSchema.java | 28 +- .../tests/util/AccountsHelper.java | 1 + .../textsecuregcm/util/MockUtils.java | 32 ++- 13 files changed, 653 insertions(+), 142 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedUsername.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/UsernameLinkHandle.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 0730168db..9d29566d8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -20,6 +20,8 @@ import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.responses.ApiResponse; import java.io.IOException; import java.security.SecureRandom; import java.time.Clock; @@ -32,6 +34,7 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletionException; +import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; import javax.validation.Valid; import javax.validation.constraints.NotNull; @@ -77,6 +80,7 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest; import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.DeviceName; +import org.whispersystems.textsecuregcm.entities.EncryptedUsername; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; @@ -85,6 +89,7 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.UsernameHashResponse; +import org.whispersystems.textsecuregcm.entities.UsernameLinkHandle; import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -421,7 +426,7 @@ public class AccountController { } if (availableForTransfer.orElse(false) && existingAccount.map(Account::isTransferSupported).orElse(false)) { - throw new WebApplicationException(Response.status(409).build()); + throw new WebApplicationException(Status.CONFLICT); } rateLimiters.getVerifyLimiter().clear(number); @@ -699,7 +704,8 @@ public class AccountController { @DELETE @Path("/username_hash") @Produces(MediaType.APPLICATION_JSON) - public void deleteUsernameHash(@Auth AuthenticatedAccount auth) { + public void deleteUsernameHash(final @Auth AuthenticatedAccount auth) { + clearUsernameLink(auth.getAccount()); accounts.clearUsernameHash(auth.getAccount()); } @@ -736,9 +742,10 @@ public class AccountController { @Path("/username_hash/confirm") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) - public UsernameHashResponse confirmUsernameHash(@Auth AuthenticatedAccount auth, - @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent, - @NotNull @Valid ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException { + public UsernameHashResponse confirmUsernameHash( + @Auth final AuthenticatedAccount auth, + @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent, + @NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException { rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid()); try { @@ -747,6 +754,11 @@ public class AccountController { throw new WebApplicationException(Response.status(422).build()); } + // Whenever a valid request for a username change arrives, + // we're making sure to clear username link. This may happen before and username changes are written to the db + // but verifying zk proof means that request itself is valid from the client's perspective + clearUsernameLink(auth.getAccount()); + try { final Account account = accounts.confirmReservedUsernameHash(auth.getAccount(), confirmRequest.usernameHash()); return account @@ -796,6 +808,90 @@ public class AccountController { .orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND)); } + @Timed + @PUT + @Path("/username_link") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + @Operation( + summary = "Set username link", + description = """ + Authenticated endpoint. For the given encrypted username generates a username link handle. + Username link handle could be used to lookup the encrypted username. + An account can only have one username link at a time. Calling this endpoint will reset previously stored + encrypted username and deactivate previous link handle. + """ + ) + @ApiResponse(responseCode = "200", description = "Username Link updated successfully.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "401", description = "Account authentication check failed.") + @ApiResponse(responseCode = "409", description = "Username is not set for the account.") + @ApiResponse(responseCode = "422", description = "Invalid request format.") + @ApiResponse(responseCode = "429", description = "Ratelimited.") + public UsernameLinkHandle updateUsernameLink( + @Auth final AuthenticatedAccount auth, + @NotNull @Valid final EncryptedUsername encryptedUsername) throws RateLimitExceededException { + // check ratelimiter for username link operations + rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid()); + + // check if username hash is set for the account + if (auth.getAccount().getUsernameHash().isEmpty()) { + throw new WebApplicationException(Status.CONFLICT); + } + + final UUID usernameLinkHandle = UUID.randomUUID(); + updateUsernameLink(auth.getAccount(), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue()); + return new UsernameLinkHandle(usernameLinkHandle); + } + + @Timed + @DELETE + @Path("/username_link") + @Operation( + summary = "Delete username link", + description = """ + Authenticated endpoint. Deletes username link for the given account: previously store encrypted username is deleted + and username link handle is deactivated. + """ + ) + @ApiResponse(responseCode = "204", description = "Username Link successfully deleted.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "401", description = "Account authentication check failed.") + @ApiResponse(responseCode = "429", description = "Ratelimited.") + public void deleteUsernameLink(@Auth final AuthenticatedAccount auth) throws RateLimitExceededException { + // check ratelimiter for username link operations + rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid()); + clearUsernameLink(auth.getAccount()); + } + + @Timed + @GET + @Path("/username_link/{uuid}") + @Produces(MediaType.APPLICATION_JSON) + @RateLimitedByIp(RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP) + @Operation( + summary = "Lookup username link", + description = """ + Enforced unauthenticated endpoint. For the given username link handle, looks up the database for an associated encrypted username. + If found, encrypted username is returned, otherwise responds with 404 Not Found. + """ + ) + @ApiResponse(responseCode = "200", description = "Username link with the given handle was found.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "404", description = "Username link was not found for the given handle.") + @ApiResponse(responseCode = "422", description = "Invalid request format.") + @ApiResponse(responseCode = "429", description = "Ratelimited.") + public EncryptedUsername lookupUsernameLink( + @Auth Optional authenticatedAccount, + @PathParam("uuid") final UUID usernameLinkHandle) { + final Optional maybeEncryptedUsername = accounts.getByUsernameLinkHandle(usernameLinkHandle) + .flatMap(Account::getEncryptedUsername); + if (authenticatedAccount.isPresent()) { + throw new ForbiddenException("must not use authenticated connection for connection graph revealing operations"); + } + if (maybeEncryptedUsername.isEmpty()) { + throw new WebApplicationException(Status.NOT_FOUND); + } + return new EncryptedUsername(maybeEncryptedUsername.get()); + } + @HEAD @Path("/account/{uuid}") @RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE) @@ -915,6 +1011,20 @@ public class AccountController { } } + private void clearUsernameLink(final Account account) { + updateUsernameLink(account, null, null); + } + + private void updateUsernameLink( + final Account account, + @Nullable final UUID usernameLinkHandle, + @Nullable final byte[] encryptedUsername) { + if ((encryptedUsername == null) ^ (usernameLinkHandle == null)) { + throw new IllegalStateException("Both or neither arguments must be null"); + } + accounts.update(account, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername)); + } + private void rethrowRateLimitException(final CompletionException completionException) throws RateLimitExceededException { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedUsername.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedUsername.java new file mode 100644 index 000000000..0e913dd3d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/EncryptedUsername.java @@ -0,0 +1,12 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Size; + +public record EncryptedUsername(@NotNull @Size(min = 1, max = 128) byte[] usernameLinkEncryptedValue) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/UsernameLinkHandle.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/UsernameLinkHandle.java new file mode 100644 index 000000000..0a4114301 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/UsernameLinkHandle.java @@ -0,0 +1,12 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import java.util.UUID; +import javax.validation.constraints.NotNull; + +public record UsernameLinkHandle(@NotNull UUID usernameLinkHandle) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index ca42dfd67..ec4c5739d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -37,6 +37,8 @@ public class RateLimiters extends BaseRateLimiters { USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), + USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))), + USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofMillis(60))), REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofMillis(500))), VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofMillis(500))), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index d516d4360..300b97454 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.storage; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import java.time.Clock; import java.time.Instant; import java.util.ArrayList; @@ -16,8 +18,6 @@ import java.util.Optional; import java.util.UUID; import java.util.function.Predicate; import javax.annotation.Nullable; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; @@ -54,6 +54,14 @@ public class Account { @Nullable private byte[] reservedUsernameHash; + @JsonIgnore + @Nullable + private UUID usernameLinkHandle; + + @JsonProperty("eu") + @Nullable + private byte[] encryptedUsername; + @JsonProperty private List devices = new ArrayList<>(); @@ -162,6 +170,38 @@ public class Account { this.reservedUsernameHash = reservedUsernameHash; } + @Nullable + public UUID getUsernameLinkHandle() { + requireNotStale(); + return usernameLinkHandle; + } + + public Optional getEncryptedUsername() { + requireNotStale(); + return Optional.ofNullable(encryptedUsername); + } + + public void setUsernameLinkDetails(@Nullable final UUID usernameLinkHandle, @Nullable final byte[] encryptedUsername) { + requireNotStale(); + if ((usernameLinkHandle == null) ^ (encryptedUsername == null)) { + throw new IllegalArgumentException("Both or neither arguments must be null"); + } + if (usernameHash == null && encryptedUsername != null) { + throw new IllegalArgumentException("usernameHash field must be set to store username link"); + } + this.encryptedUsername = encryptedUsername; + this.usernameLinkHandle = usernameLinkHandle; + } + + /* + * This method is intentionally left package-private so that it's only used + * when Account is read from DB + */ + void setUsernameLinkHandle(@Nullable final UUID usernameLinkHandle) { + requireNotStale(); + this.usernameLinkHandle = usernameLinkHandle; + } + public void addDevice(Device device) { requireNotStale(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index e26d62076..88212fde8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.storage; @@ -45,6 +45,8 @@ import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.Put; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; @@ -68,6 +70,7 @@ public class Accounts extends AbstractDynamoDbStore { private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update")); private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber")); private static final Timer GET_BY_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameHash")); + private static final Timer GET_BY_USERNAME_LINK_HANDLE_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameLinkHandle")); private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni")); private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid")); private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom")); @@ -82,6 +85,8 @@ public class Accounts extends AbstractDynamoDbStore { static final String KEY_ACCOUNT_UUID = "U"; // uuid, attribute on account table, primary key for PNI table static final String ATTR_PNI_UUID = "PNI"; + // uuid of the current username link or null + static final String ATTR_USERNAME_LINK_UUID = "UL"; // phone number static final String ATTR_ACCOUNT_E164 = "P"; // account, serialized to JSON @@ -99,6 +104,8 @@ public class Accounts extends AbstractDynamoDbStore { // time to live; number static final String ATTR_TTL = "TTL"; + static final String USERNAME_LINK_TO_UUID_INDEX = "ul_to_u"; + private final Clock clock; private final DynamoDbAsyncClient asyncClient; @@ -525,20 +532,28 @@ public class Accounts extends AbstractDynamoDbStore { ":version", AttributeValues.fromInt(account.getVersion()), ":version_increment", AttributeValues.fromInt(1))); - final String updateExpression; + final StringBuilder updateExpressionBuilder = new StringBuilder("SET #data = :data, #cds = :cds"); if (account.getUnidentifiedAccessKey().isPresent()) { // if it's present in the account, also set the uak attrNames.put("#uak", ATTR_UAK); attrValues.put(":uak", AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get())); - updateExpression = "SET #data = :data, #cds = :cds, #uak = :uak ADD #version :version_increment"; - } else { - updateExpression = "SET #data = :data, #cds = :cds ADD #version :version_increment"; + updateExpressionBuilder.append(", #uak = :uak"); + } + if (account.getEncryptedUsername().isPresent() && account.getUsernameLinkHandle() != null) { + attrNames.put("#ul", ATTR_USERNAME_LINK_UUID); + attrValues.put(":ul", AttributeValues.fromUUID(account.getUsernameLinkHandle())); + updateExpressionBuilder.append(", #ul = :ul"); + } + updateExpressionBuilder.append(" ADD #version :version_increment"); + if (account.getEncryptedUsername().isEmpty() || account.getUsernameLinkHandle() == null) { + attrNames.put("#ul", ATTR_USERNAME_LINK_UUID); + updateExpressionBuilder.append(" REMOVE #ul"); } updateItemRequest = UpdateItemRequest.builder() .tableName(accountsTableName) .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) - .updateExpression(updateExpression) + .updateExpression(updateExpressionBuilder.toString()) .conditionExpression("attribute_exists(#number) AND #version = :version") .expressionAttributeNames(attrNames) .expressionAttributeValues(attrValues) @@ -630,11 +645,18 @@ public class Accounts extends AbstractDynamoDbStore { ); } + @Nonnull + public Optional getByUsernameLinkHandle(final UUID usernameLinkHandle) { + return requireNonNull(GET_BY_USERNAME_LINK_HANDLE_TIMER.record(() -> + itemByGsiKey(accountsTableName, USERNAME_LINK_TO_UUID_INDEX, ATTR_USERNAME_LINK_UUID, AttributeValues.fromUUID(usernameLinkHandle)) + .map(Accounts::fromItem))); + } + @Nonnull public Optional getByAccountIdentifier(final UUID uuid) { return requireNonNull(GET_BY_UUID_TIMER.record(() -> - itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)) - .map(Accounts::fromItem))); + itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)) + .map(Accounts::fromItem))); } public void delete(final UUID uuid) { @@ -706,6 +728,33 @@ public class Accounts extends AbstractDynamoDbStore { return Optional.ofNullable(response.item()).filter(m -> !m.isEmpty()); } + @Nonnull + private Optional> itemByGsiKey(final String table, final String indexName, final String keyName, final AttributeValue keyValue) { + final QueryResponse response = db().query(QueryRequest.builder() + .tableName(table) + .indexName(indexName) + .keyConditionExpression("#gsiKey = :gsiValue") + .projectionExpression("#uuid") + .expressionAttributeNames(Map.of( + "#gsiKey", keyName, + "#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of( + ":gsiValue", keyValue)) + .build()); + + if (response.count() == 0) { + return Optional.empty(); + } + + if (response.count() > 1) { + throw new IllegalStateException("More than one row located for GSI [%s], key-value pair [%s, %s]" + .formatted(indexName, keyName, keyValue)); + } + + final AttributeValue primaryKeyValue = response.items().get(0).get(KEY_ACCOUNT_UUID); + return itemByKey(table, KEY_ACCOUNT_UUID, primaryKeyValue); + } + @Nonnull private TransactWriteItem buildAccountPut( final Account account, @@ -854,6 +903,7 @@ public class Accounts extends AbstractDynamoDbStore { account.setNumber(item.get(ATTR_ACCOUNT_E164).s(), phoneNumberIdentifierFromAttribute); account.setUuid(accountIdentifier); account.setUsernameHash(AttributeValues.getByteArray(item, ATTR_USERNAME_HASH, null)); + account.setUsernameLinkHandle(AttributeValues.getUUID(item, ATTR_USERNAME_LINK_UUID, null)); account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n())); account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE)) .map(AttributeValue::bool) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 5c755306f..6676862cb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -36,7 +36,6 @@ import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; - import org.apache.commons.lang3.ObjectUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,12 +62,14 @@ public class AccountsManager { private static final Timer updateTimer = metricRegistry.timer(name(AccountsManager.class, "update")); private static final Timer getByNumberTimer = metricRegistry.timer(name(AccountsManager.class, "getByNumber")); private static final Timer getByUsernameHashTimer = metricRegistry.timer(name(AccountsManager.class, "getByUsernameHash")); + private static final Timer getByUsernameLinkHandleTimer = metricRegistry.timer(name(AccountsManager.class, "getByUsernameLinkHandle")); private static final Timer getByUuidTimer = metricRegistry.timer(name(AccountsManager.class, "getByUuid")); private static final Timer deleteTimer = metricRegistry.timer(name(AccountsManager.class, "delete")); private static final Timer redisSetTimer = metricRegistry.timer(name(AccountsManager.class, "redisSet")); private static final Timer redisNumberGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisNumberGet")); private static final Timer redisUsernameHashGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUsernameHashGet")); + private static final Timer redisUsernameLinkHandleGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUsernameLinkHandleGet")); private static final Timer redisPniGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisPniGet")); private static final Timer redisUuidGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUuidGet")); private static final Timer redisDeleteTimer = metricRegistry.timer(name(AccountsManager.class, "redisDelete")); @@ -620,55 +621,44 @@ public class AccountsManager { }); } - public Optional getByE164(String number) { - try (Timer.Context ignored = getByNumberTimer.time()) { - Optional account = redisGetByE164(number); - - if (account.isEmpty()) { - account = accounts.getByE164(number); - account.ifPresent(this::redisSet); - } - - return account; - } + public Optional getByE164(final String number) { + return checkRedisThenAccounts( + getByNumberTimer, + () -> redisGetBySecondaryKey(getAccountMapKey(number), redisNumberGetTimer), + () -> accounts.getByE164(number) + ); } - public Optional getByPhoneNumberIdentifier(UUID pni) { - try (Timer.Context ignored = getByNumberTimer.time()) { - Optional account = redisGetByPhoneNumberIdentifier(pni); + public Optional getByPhoneNumberIdentifier(final UUID pni) { + return checkRedisThenAccounts( + getByNumberTimer, + () -> redisGetBySecondaryKey(getAccountMapKey(pni.toString()), redisPniGetTimer), + () -> accounts.getByPhoneNumberIdentifier(pni) + ); + } - if (account.isEmpty()) { - account = accounts.getByPhoneNumberIdentifier(pni); - account.ifPresent(this::redisSet); - } - - return account; - } + public Optional getByUsernameLinkHandle(final UUID usernameLinkHandle) { + return checkRedisThenAccounts( + getByUsernameLinkHandleTimer, + () -> redisGetBySecondaryKey(getAccountMapKey(usernameLinkHandle.toString()), redisUsernameLinkHandleGetTimer), + () -> accounts.getByUsernameLinkHandle(usernameLinkHandle) + ); } public Optional getByUsernameHash(final byte[] usernameHash) { - try (final Timer.Context ignored = getByUsernameHashTimer.time()) { - Optional account = redisGetByUsernameHash(usernameHash); - if (account.isEmpty()) { - account = accounts.getByUsernameHash(usernameHash); - account.ifPresent(this::redisSet); - } - - return account; - } + return checkRedisThenAccounts( + getByUsernameHashTimer, + () -> redisGetBySecondaryKey(getUsernameHashAccountMapKey(usernameHash), redisUsernameHashGetTimer), + () -> accounts.getByUsernameHash(usernameHash) + ); } - public Optional getByAccountIdentifier(UUID uuid) { - try (Timer.Context ignored = getByUuidTimer.time()) { - Optional account = redisGetByAccountIdentifier(uuid); - - if (account.isEmpty()) { - account = accounts.getByAccountIdentifier(uuid); - account.ifPresent(this::redisSet); - } - - return account; - } + public Optional getByAccountIdentifier(final UUID uuid) { + return checkRedisThenAccounts( + getByUuidTimer, + () -> redisGetByAccountIdentifier(uuid), + () -> accounts.getByAccountIdentifier(uuid) + ); } public UUID getPhoneNumberIdentifier(String e164) { @@ -758,24 +748,25 @@ public class AccountsManager { } } - private Optional redisGetByPhoneNumberIdentifier(UUID uuid) { - return redisGetBySecondaryKey(getAccountMapKey(uuid.toString()), redisPniGetTimer); + private Optional checkRedisThenAccounts( + final Timer overallTimer, + final Supplier> resolveFromRedis, + final Supplier> resolveFromAccounts) { + try (final Timer.Context ignored = overallTimer.time()) { + Optional account = resolveFromRedis.get(); + if (account.isEmpty()) { + account = resolveFromAccounts.get(); + account.ifPresent(this::redisSet); + } + return account; + } } - private Optional redisGetByE164(String e164) { - return redisGetBySecondaryKey(getAccountMapKey(e164), redisNumberGetTimer); - } - - private Optional redisGetByUsernameHash(byte[] usernameHash) { - return redisGetBySecondaryKey(getUsernameHashAccountMapKey(usernameHash), redisUsernameHashGetTimer); - } - - private Optional redisGetBySecondaryKey(String secondaryKey, Timer timer) { - try (Timer.Context ignored = timer.time()) { - final String uuid = cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey)); - - if (uuid != null) return redisGetByAccountIdentifier(UUID.fromString(uuid)); - else return Optional.empty(); + private Optional redisGetBySecondaryKey(final String secondaryKey, final Timer timer) { + try (final Timer.Context ignored = timer.time()) { + return Optional.ofNullable(cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey))) + .map(UUID::fromString) + .flatMap(this::getByAccountIdentifier); } catch (IllegalArgumentException e) { logger.warn("Deserialization error", e); return Optional.empty(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/UsernameHashZkProofVerifier.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/UsernameHashZkProofVerifier.java index 4570060be..3851510cd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/UsernameHashZkProofVerifier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/UsernameHashZkProofVerifier.java @@ -1,10 +1,15 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + package org.whispersystems.textsecuregcm.util; import org.signal.libsignal.usernames.BaseUsernameException; import org.signal.libsignal.usernames.Username; public class UsernameHashZkProofVerifier { - public void verifyProof(byte[] proof, byte[] hash) throws BaseUsernameException { + public void verifyProof(final byte[] proof, final byte[] hash) throws BaseUsernameException { Username.verifyProof(proof, hash); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index f21b1493f..15f354b31 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -54,6 +55,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nullable; import javax.ws.rs.client.Entity; +import javax.ws.rs.client.Invocation; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.apache.commons.lang3.RandomUtils; @@ -93,6 +95,7 @@ import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest; import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest; +import org.whispersystems.textsecuregcm.entities.EncryptedUsername; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.IncomingMessage; import org.whispersystems.textsecuregcm.entities.RegistrationLock; @@ -1081,7 +1084,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/1234") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER, "bar")) .put(Entity.entity(new AccountAttributes(), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); verify(accountsManager).create(eq(SENDER), eq("bar"), any(), any(), anyList()); @@ -1094,7 +1097,7 @@ class AccountControllerTest { final Response response = resources.getJerseyTest() .target("/v1/accounts/code/1234") .request() - .header("Authorization", "This is not a valid authorization header") + .header(HttpHeaders.AUTHORIZATION, "This is not a valid authorization header") .put(Entity.entity(new AccountAttributes(), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(401); @@ -1106,7 +1109,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/1234") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_OLD, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_OLD, "bar")) .put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1130,7 +1133,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/1111") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER, "bar")) .put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1155,7 +1158,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/666666") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) .put(Entity.entity( new AccountAttributes(false, 3333, null, HexFormat.of().formatHex(registration_lock_key), true, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -1180,7 +1183,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/666666") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) .put(Entity.entity( new AccountAttributes(false, 3333, null, HexFormat.of().formatHex(registration_lock_key), true, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -1214,7 +1217,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/666666") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) .put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -1241,7 +1244,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/666666") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) .put(Entity.entity(new AccountAttributes(false, 3333, null, HexFormat.of().formatHex(new byte[32]), true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1268,7 +1271,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/666666") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar")) .put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1305,7 +1308,7 @@ class AccountControllerTest { .target("/v1/accounts/code/1234") .queryParam("transfer", true) .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) .put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1329,7 +1332,7 @@ class AccountControllerTest { .target("/v1/accounts/code/1234") .queryParam("transfer", true) .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) .put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1352,7 +1355,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/code/1234") .request() - .header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar")) .put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1375,7 +1378,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -1397,7 +1400,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1415,7 +1418,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1434,7 +1437,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); @@ -1452,7 +1455,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1476,7 +1479,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1512,7 +1515,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1547,7 +1550,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1587,7 +1590,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1626,7 +1629,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null), MediaType.APPLICATION_JSON_TYPE)); @@ -1675,7 +1678,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest( number, code, null, pniIdentityKey, deviceMessages, @@ -1729,7 +1732,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/number") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.entity(new ChangePhoneNumberRequest( AuthHelper.VALID_NUMBER, code, null, pniIdentityKey, deviceMessages, @@ -1754,7 +1757,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/registration_lock/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234"))); assertThat(response.getStatus()).isEqualTo(204); @@ -1776,7 +1779,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/registration_lock/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new RegistrationLock("313"))); assertThat(response.getStatus()).isEqualTo(422); @@ -1788,7 +1791,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/registration_lock/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234"))); assertThat(response.getStatus()).isEqualTo(401); @@ -1800,7 +1803,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/gcm/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .put(Entity.json(new GcmRegistrationId("z000"))); assertThat(response.getStatus()).isEqualTo(204); @@ -1815,7 +1818,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/gcm/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .put(Entity.json("{}")); assertThat(response.getStatus()).isEqualTo(422); @@ -1828,7 +1831,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/apn/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .put(Entity.json(new ApnRegistrationId("first", "second"))); assertThat(response.getStatus()).isEqualTo(204); @@ -1844,7 +1847,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/apn/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .put(Entity.json(new ApnRegistrationId("first", null))); assertThat(response.getStatus()).isEqualTo(204); @@ -1870,7 +1873,7 @@ class AccountControllerTest { final Response response = resources.getJerseyTest() .target(path) .request() - .header("Authorization", AuthHelper.getAuthHeader(aci, password)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(aci, password)) .get(); assertThat(response.getStatus()).isEqualTo(expectedHttpStatusCode); @@ -1889,6 +1892,145 @@ class AccountControllerTest { ); } + static Stream testSetUsernameLink() { + return Stream.of( + Arguments.of(false, true, true, 32, 401), + Arguments.of(true, true, false, 32, 409), + Arguments.of(true, true, true, 129, 422), + Arguments.of(true, true, true, 0, 422), + Arguments.of(true, false, true, 32, 429), + Arguments.of(true, true, true, 128, 200) + ); + } + + @ParameterizedTest + @MethodSource + public void testSetUsernameLink( + final boolean auth, + final boolean passRateLimiting, + final boolean setUsernameHash, + final int payloadSize, + final int expectedStatus) throws Exception { + + // checking if rate limiting needs to pass or fail for this test + if (passRateLimiting) { + MockUtils.updateRateLimiterResponseToAllow( + rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID); + } else { + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID, Duration.ofMinutes(10), false); + } + + // checking if username is to be set for this test + if (setUsernameHash) { + when(AuthHelper.VALID_ACCOUNT.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH_1)); + } else { + when(AuthHelper.VALID_ACCOUNT.getUsernameHash()).thenReturn(Optional.empty()); + } + + final Invocation.Builder builder = resources.getJerseyTest() + .target("/v1/accounts/username_link") + .request(); + + // checking if auth is needed for this test + if (auth) { + builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + } + + // make sure `update()` works + doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any()); + + final Response put = builder.put(Entity.json(new EncryptedUsername(RandomUtils.nextBytes(payloadSize)))); + + assertEquals(expectedStatus, put.getStatus()); + } + + static Stream testDeleteUsernameLink() { + return Stream.of( + Arguments.of(false, true, 401), + Arguments.of(true, false, 429), + Arguments.of(true, true, 204) + ); + } + + @ParameterizedTest + @MethodSource + public void testDeleteUsernameLink( + final boolean auth, + final boolean passRateLimiting, + final int expectedStatus) throws Exception { + + // checking if rate limiting needs to pass or fail for this test + if (passRateLimiting) { + MockUtils.updateRateLimiterResponseToAllow( + rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID); + } else { + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID, Duration.ofMinutes(10), false); + } + + final Invocation.Builder builder = resources.getJerseyTest() + .target("/v1/accounts/username_link") + .request(); + + // checking if auth is needed for this test + if (auth) { + builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + } + + // make sure `update()` works + doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any()); + + final Response delete = builder.delete(); + + assertEquals(expectedStatus, delete.getStatus()); + } + + static Stream testLookupUsernameLink() { + return Stream.of( + Arguments.of(false, true, true, true, 403), + Arguments.of(true, false, true, true, 429), + Arguments.of(true, true, false, true, 404), + Arguments.of(true, true, true, false, 404), + Arguments.of(true, true, true, true, 200) + ); + } + + @ParameterizedTest + @MethodSource + public void testLookupUsernameLink( + final boolean stayUnauthenticated, + final boolean passRateLimiting, + final boolean validUuidInput, + final boolean locateLinkByUuid, + final int expectedStatus) throws Exception { + + MockUtils.updateRateLimiterResponseToAllow( + rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, NICE_HOST); + MockUtils.updateRateLimiterResponseToFail( + rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, RATE_LIMITED_IP_HOST, Duration.ofMinutes(10), false); + + final String uuid = validUuidInput ? UUID.randomUUID().toString() : "invalid-uuid"; + + if (validUuidInput && locateLinkByUuid) { + final Account account = mock(Account.class); + doReturn(Optional.of(RandomUtils.nextBytes(16))).when(account).getEncryptedUsername(); + doReturn(Optional.of(account)).when(accountsManager).getByUsernameLinkHandle(eq(UUID.fromString(uuid))); + } + + final Invocation.Builder builder = resources.getJerseyTest() + .target("/v1/accounts/username_link/" + uuid) + .request(); + if (!stayUnauthenticated) { + builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + } + final Response get = builder + .header(HttpHeaders.X_FORWARDED_FOR, passRateLimiting ? NICE_HOST : RATE_LIMITED_IP_HOST) + .get(); + + assertEquals(expectedStatus, get.getStatus()); + } + @Test void testReserveUsernameHash() throws UsernameHashNotAvailableException { when(accountsManager.reserveUsernameHash(any(), any())) @@ -1897,7 +2039,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ReserveUsernameHashRequest(List.of(USERNAME_HASH_1, USERNAME_HASH_2)))); assertThat(response.getStatus()).isEqualTo(200); assertThat(response.readEntity(ReserveUsernameHashResponse.class)) @@ -1912,7 +2054,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ReserveUsernameHashRequest(List.of(USERNAME_HASH_1, USERNAME_HASH_2)))); assertThat(response.getStatus()).isEqualTo(409); } @@ -1924,7 +2066,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ReserveUsernameHashRequest(usernameHashes))); assertThat(response.getStatus()).isEqualTo(422); } @@ -1943,7 +2085,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ReserveUsernameHashRequest(usernameHashes))); assertThat(response.getStatus()).isEqualTo(422); } @@ -1954,7 +2096,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ReserveUsernameHashRequest(null))); assertThat(response.getStatus()).isEqualTo(422); } @@ -1965,7 +2107,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/reserve") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json( // Has '+' and '='characters which are invalid in base64url """ @@ -1986,7 +2128,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF))); assertThat(response.getStatus()).isEqualTo(200); assertArrayEquals(response.readEntity(UsernameHashResponse.class).usernameHash(), USERNAME_HASH_1); @@ -2002,7 +2144,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF))); assertThat(response.getStatus()).isEqualTo(409); verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1); @@ -2017,7 +2159,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF))); assertThat(response.getStatus()).isEqualTo(410); verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1); @@ -2029,7 +2171,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json( // Has '+' and '='characters which are invalid in base64url """ @@ -2049,7 +2191,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ConfirmUsernameHashRequest(usernameHash, ZK_PROOF))); assertThat(response.getStatus()).isEqualTo(422); verifyNoInteractions(usernameZkProofVerifier); @@ -2062,7 +2204,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/confirm") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF))); assertThat(response.getStatus()).isEqualTo(422); verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1); @@ -2074,7 +2216,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .delete(); assertThat(response.getStatus()).isEqualTo(204); @@ -2087,7 +2229,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/username_hash/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD)) .delete(); assertThat(response.getStatus()).isEqualTo(401); @@ -2099,7 +2241,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/attributes/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null))); assertThat(response.getStatus()).isEqualTo(204); @@ -2111,7 +2253,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/attributes/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD)) .put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null))); assertThat(response.getStatus()).isEqualTo(204); @@ -2124,7 +2266,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/attributes/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD)) .put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null) .withRecoveryPassword(recoveryPassword))); @@ -2138,7 +2280,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/attributes/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new AccountAttributes(false, 2222, null, null, false, null))); assertThat(response.getStatus()).isEqualTo(204); @@ -2150,7 +2292,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/attributes/") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .put(Entity.json(new AccountAttributes(false, 2222, null, null, false, null) .withUnidentifiedAccessKey(new byte[7]))); @@ -2163,7 +2305,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/me") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .delete(); assertThat(response.getStatus()).isEqualTo(204); @@ -2178,7 +2320,7 @@ class AccountControllerTest { resources.getJerseyTest() .target("/v1/accounts/me") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .delete(); assertThat(response.getStatus()).isEqualTo(500); @@ -2300,7 +2442,7 @@ class AccountControllerTest { assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", UUID.randomUUID())) .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .head() .getStatus()).isEqualTo(400); @@ -2352,7 +2494,7 @@ class AccountControllerTest { assertThat(resources.getJerseyTest() .target(String.format("/v1/accounts/username_hash/%s", USERNAME_HASH_1)) .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1") .get() .getStatus()).isEqualTo(400); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 7ff4e13f3..3327c2dd5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -12,6 +12,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -20,6 +21,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.uuid.UUIDComparator; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; +import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -32,12 +34,21 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.apache.commons.lang3.RandomUtils; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; +import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; +import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; @@ -68,6 +79,8 @@ class AccountsTest { private static final int SCAN_PAGE_SIZE = 1; + private static final AtomicInteger ACCOUNT_COUNTER = new AtomicInteger(1); + @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( @@ -100,6 +113,88 @@ class AccountsTest { SCAN_PAGE_SIZE); } + @Test + public void testStoreAndLookupUsernameLink() throws Exception { + final Account account = nextRandomAccount(); + account.setUsernameHash(RandomUtils.nextBytes(16)); + accounts.create(account); + + final BiConsumer, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> { + assertTrue(maybeAccount.isPresent()); + assertTrue(maybeAccount.get().getEncryptedUsername().isPresent()); + assertEquals(account.getUuid(), maybeAccount.get().getUuid()); + assertArrayEquals(expectedEncryptedUsername, maybeAccount.get().getEncryptedUsername().get()); + }; + + // creating a username link, storing it, checking that it can be looked up + final UUID linkHandle1 = UUID.randomUUID(); + final byte[] encruptedUsername1 = RandomUtils.nextBytes(32); + account.setUsernameLinkDetails(linkHandle1, encruptedUsername1); + accounts.update(account); + validator.accept(accounts.getByUsernameLinkHandle(linkHandle1), encruptedUsername1); + + // updating username link, storing new one, checking that it can be looked up, checking that old one can't be looked up + final UUID linkHandle2 = UUID.randomUUID(); + final byte[] encruptedUsername2 = RandomUtils.nextBytes(32); + account.setUsernameLinkDetails(linkHandle2, encruptedUsername2); + accounts.update(account); + validator.accept(accounts.getByUsernameLinkHandle(linkHandle2), encruptedUsername2); + assertTrue(accounts.getByUsernameLinkHandle(linkHandle1).isEmpty()); + + // deleting username link, checking it can't be looked up by either handle + account.setUsernameLinkDetails(null, null); + accounts.update(account); + assertTrue(accounts.getByUsernameLinkHandle(linkHandle1).isEmpty()); + assertTrue(accounts.getByUsernameLinkHandle(linkHandle2).isEmpty()); + } + + @Test + public void testUsernameLinksViaAccountsManager() throws Exception { + final AccountsManager accountsManager = new AccountsManager( + accounts, + mock(PhoneNumberIdentifiers.class), + mock(FaultTolerantRedisCluster.class), + mock(DeletedAccountsManager.class), + mock(Keys.class), + mock(MessagesManager.class), + mock(ProfilesManager.class), + mock(StoredVerificationCodeManager.class), + mock(SecureStorageClient.class), + mock(SecureBackupClient.class), + mock(SecureValueRecovery2Client.class), + mock(ClientPresenceManager.class), + mock(ExperimentEnrollmentManager.class), + mock(RegistrationRecoveryPasswordsManager.class), + mock(Clock.class)); + + final Account account = nextRandomAccount(); + account.setUsernameHash(RandomUtils.nextBytes(16)); + accounts.create(account); + + final UUID linkHandle = UUID.randomUUID(); + final byte[] encruptedUsername = RandomUtils.nextBytes(32); + accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, encruptedUsername)); + + final Optional maybeAccount = accountsManager.getByUsernameLinkHandle(linkHandle); + assertTrue(maybeAccount.isPresent()); + assertTrue(maybeAccount.get().getEncryptedUsername().isPresent()); + assertArrayEquals(encruptedUsername, maybeAccount.get().getEncryptedUsername().get()); + + // making some unrelated change and updating account to check that username link data is still there + final Optional accountToChange = accountsManager.getByAccountIdentifier(account.getUuid()); + assertTrue(accountToChange.isPresent()); + accountsManager.update(accountToChange.get(), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber())); + final Optional accountAfterChange = accountsManager.getByUsernameLinkHandle(linkHandle); + assertTrue(accountAfterChange.isPresent()); + assertTrue(accountAfterChange.get().getEncryptedUsername().isPresent()); + assertArrayEquals(encruptedUsername, accountAfterChange.get().getEncryptedUsername().get()); + + // now deleting the link + final Optional accountToDeleteLink = accountsManager.getByAccountIdentifier(account.getUuid()); + accountsManager.update(accountToDeleteLink.get(), a -> a.setUsernameLinkDetails(null, null)); + assertTrue(accounts.getByUsernameLinkHandle(linkHandle).isEmpty()); + } + @Test void testStore() { Device device = generateDevice(1); @@ -818,19 +913,24 @@ class AccountsTest { assertThat(account.getUsernameHash()).isEmpty(); } - private Device generateDevice(long id) { + private static Device generateDevice(long id) { return DevicesHelper.createDevice(id); } - private Account generateAccount(String number, UUID uuid, final UUID pni) { + private static Account nextRandomAccount() { + final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement()); + return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID()); + } + + private static Account generateAccount(String number, UUID uuid, final UUID pni) { Device device = generateDevice(1); return generateAccount(number, uuid, pni, List.of(device)); } - private Account generateAccount(String number, UUID uuid, final UUID pni, List devices) { - byte[] unidentifiedAccessKey = new byte[16]; - Random random = new Random(System.currentTimeMillis()); - Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255)); + private static Account generateAccount(String number, UUID uuid, final UUID pni, List devices) { + final byte[] unidentifiedAccessKey = new byte[16]; + final Random random = new Random(System.currentTimeMillis()); + Arrays.fill(unidentifiedAccessKey, (byte) random.nextInt(255)); return AccountsHelper.generateTestAccount(number, uuid, pni, devices, unidentifiedAccessKey); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java index 9cff9c791..7024fe994 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java @@ -23,11 +23,29 @@ public final class DynamoDbExtensionSchema { ACCOUNTS("accounts_test", Accounts.KEY_ACCOUNT_UUID, null, - List.of(AttributeDefinition.builder() - .attributeName(Accounts.KEY_ACCOUNT_UUID) - .attributeType(ScalarAttributeType.B) - .build()), - List.of(), List.of()), + List.of( + AttributeDefinition.builder() + .attributeName(Accounts.KEY_ACCOUNT_UUID) + .attributeType(ScalarAttributeType.B) + .build(), + AttributeDefinition.builder() + .attributeName(Accounts.ATTR_USERNAME_LINK_UUID) + .attributeType(ScalarAttributeType.B) + .build()), + List.of( + GlobalSecondaryIndex.builder() + .indexName(Accounts.USERNAME_LINK_TO_UUID_INDEX) + .keySchema( + KeySchemaElement.builder() + .attributeName(Accounts.ATTR_USERNAME_LINK_UUID) + .keyType(KeyType.HASH) + .build() + ) + .projection(Projection.builder().projectionType(ProjectionType.KEYS_ONLY).build()) + .provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(10L).writeCapacityUnits(10L).build()) + .build() + ), + List.of()), DELETED_ACCOUNTS("deleted_accounts_test", DeletedAccounts.KEY_ACCOUNT_E164, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 15f1324a4..927be5d50 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -114,6 +114,7 @@ public class AccountsHelper { case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing); case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing); case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing); + case "getUsernameHash" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing); case "getDevices" -> when(updatedAccount.getDevices()).thenAnswer(stubbing); case "getDevice" -> when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); case "getMasterDevice" -> when(updatedAccount.getMasterDevice()).thenAnswer(stubbing); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index c3161821f..16466b8fd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -11,7 +11,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import java.time.Duration; -import java.util.Optional; +import java.util.UUID; import org.apache.commons.lang3.RandomUtils; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes; @@ -50,7 +50,20 @@ public final class MockUtils { final RateLimiters.For handle, final String input) { final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); - doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).forDescriptor(eq(handle)); + doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); + try { + doNothing().when(mockRateLimiter).validate(eq(input)); + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } + + public static void updateRateLimiterResponseToAllow( + final RateLimiters rateLimitersMock, + final RateLimiters.For handle, + final UUID input) { + final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); + doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); try { doNothing().when(mockRateLimiter).validate(eq(input)); } catch (final RateLimitExceededException e) { @@ -73,6 +86,21 @@ public final class MockUtils { } } + public static void updateRateLimiterResponseToFail( + final RateLimiters rateLimitersMock, + final RateLimiters.For handle, + final UUID input, + final Duration retryAfter, + final boolean legacyStatusCode) { + final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); + doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle)); + try { + doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input)); + } catch (final RateLimitExceededException e) { + throw new RuntimeException(e); + } + } + public static SecretBytes randomSecretBytes(final int size) { return new SecretBytes(RandomUtils.nextBytes(size)); }