username links API

This commit is contained in:
Sergey Skrobotov 2023-06-02 10:15:09 -07:00
parent ecd207f0a1
commit 47cc7fd615
13 changed files with 653 additions and 142 deletions

View File

@ -20,6 +20,8 @@ import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.io.IOException;
import java.security.SecureRandom;
import java.time.Clock;
@ -32,6 +34,7 @@ import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
@ -77,6 +80,7 @@ import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.DeviceName;
import org.whispersystems.textsecuregcm.entities.EncryptedUsername;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
@ -85,6 +89,7 @@ import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.UsernameHashResponse;
import org.whispersystems.textsecuregcm.entities.UsernameLinkHandle;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -421,7 +426,7 @@ public class AccountController {
}
if (availableForTransfer.orElse(false) && existingAccount.map(Account::isTransferSupported).orElse(false)) {
throw new WebApplicationException(Response.status(409).build());
throw new WebApplicationException(Status.CONFLICT);
}
rateLimiters.getVerifyLimiter().clear(number);
@ -699,7 +704,8 @@ public class AccountController {
@DELETE
@Path("/username_hash")
@Produces(MediaType.APPLICATION_JSON)
public void deleteUsernameHash(@Auth AuthenticatedAccount auth) {
public void deleteUsernameHash(final @Auth AuthenticatedAccount auth) {
clearUsernameLink(auth.getAccount());
accounts.clearUsernameHash(auth.getAccount());
}
@ -736,9 +742,10 @@ public class AccountController {
@Path("/username_hash/confirm")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public UsernameHashResponse confirmUsernameHash(@Auth AuthenticatedAccount auth,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent,
@NotNull @Valid ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException {
public UsernameHashResponse confirmUsernameHash(
@Auth final AuthenticatedAccount auth,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent,
@NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) throws RateLimitExceededException {
rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid());
try {
@ -747,6 +754,11 @@ public class AccountController {
throw new WebApplicationException(Response.status(422).build());
}
// Whenever a valid request for a username change arrives,
// we're making sure to clear username link. This may happen before and username changes are written to the db
// but verifying zk proof means that request itself is valid from the client's perspective
clearUsernameLink(auth.getAccount());
try {
final Account account = accounts.confirmReservedUsernameHash(auth.getAccount(), confirmRequest.usernameHash());
return account
@ -796,6 +808,90 @@ public class AccountController {
.orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND));
}
@Timed
@PUT
@Path("/username_link")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Operation(
summary = "Set username link",
description = """
Authenticated endpoint. For the given encrypted username generates a username link handle.
Username link handle could be used to lookup the encrypted username.
An account can only have one username link at a time. Calling this endpoint will reset previously stored
encrypted username and deactivate previous link handle.
"""
)
@ApiResponse(responseCode = "200", description = "Username Link updated successfully.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "409", description = "Username is not set for the account.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public UsernameLinkHandle updateUsernameLink(
@Auth final AuthenticatedAccount auth,
@NotNull @Valid final EncryptedUsername encryptedUsername) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
// check if username hash is set for the account
if (auth.getAccount().getUsernameHash().isEmpty()) {
throw new WebApplicationException(Status.CONFLICT);
}
final UUID usernameLinkHandle = UUID.randomUUID();
updateUsernameLink(auth.getAccount(), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
return new UsernameLinkHandle(usernameLinkHandle);
}
@Timed
@DELETE
@Path("/username_link")
@Operation(
summary = "Delete username link",
description = """
Authenticated endpoint. Deletes username link for the given account: previously store encrypted username is deleted
and username link handle is deactivated.
"""
)
@ApiResponse(responseCode = "204", description = "Username Link successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public void deleteUsernameLink(@Auth final AuthenticatedAccount auth) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
clearUsernameLink(auth.getAccount());
}
@Timed
@GET
@Path("/username_link/{uuid}")
@Produces(MediaType.APPLICATION_JSON)
@RateLimitedByIp(RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP)
@Operation(
summary = "Lookup username link",
description = """
Enforced unauthenticated endpoint. For the given username link handle, looks up the database for an associated encrypted username.
If found, encrypted username is returned, otherwise responds with 404 Not Found.
"""
)
@ApiResponse(responseCode = "200", description = "Username link with the given handle was found.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "404", description = "Username link was not found for the given handle.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public EncryptedUsername lookupUsernameLink(
@Auth Optional<AuthenticatedAccount> authenticatedAccount,
@PathParam("uuid") final UUID usernameLinkHandle) {
final Optional<byte[]> maybeEncryptedUsername = accounts.getByUsernameLinkHandle(usernameLinkHandle)
.flatMap(Account::getEncryptedUsername);
if (authenticatedAccount.isPresent()) {
throw new ForbiddenException("must not use authenticated connection for connection graph revealing operations");
}
if (maybeEncryptedUsername.isEmpty()) {
throw new WebApplicationException(Status.NOT_FOUND);
}
return new EncryptedUsername(maybeEncryptedUsername.get());
}
@HEAD
@Path("/account/{uuid}")
@RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
@ -915,6 +1011,20 @@ public class AccountController {
}
}
private void clearUsernameLink(final Account account) {
updateUsernameLink(account, null, null);
}
private void updateUsernameLink(
final Account account,
@Nullable final UUID usernameLinkHandle,
@Nullable final byte[] encryptedUsername) {
if ((encryptedUsername == null) ^ (usernameLinkHandle == null)) {
throw new IllegalStateException("Both or neither arguments must be null");
}
accounts.update(account, a -> a.setUsernameLinkDetails(usernameLinkHandle, encryptedUsername));
}
private void rethrowRateLimitException(final CompletionException completionException)
throws RateLimitExceededException {

View File

@ -0,0 +1,12 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
public record EncryptedUsername(@NotNull @Size(min = 1, max = 128) byte[] usernameLinkEncryptedValue) {
}

View File

@ -0,0 +1,12 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import java.util.UUID;
import javax.validation.constraints.NotNull;
public record UsernameLinkHandle(@NotNull UUID usernameLinkHandle) {
}

View File

@ -37,6 +37,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofMillis(60))),
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofMillis(500))),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofMillis(500))),

View File

@ -7,6 +7,8 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
@ -16,8 +18,6 @@ import java.util.Optional;
import java.util.UUID;
import java.util.function.Predicate;
import javax.annotation.Nullable;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
@ -54,6 +54,14 @@ public class Account {
@Nullable
private byte[] reservedUsernameHash;
@JsonIgnore
@Nullable
private UUID usernameLinkHandle;
@JsonProperty("eu")
@Nullable
private byte[] encryptedUsername;
@JsonProperty
private List<Device> devices = new ArrayList<>();
@ -162,6 +170,38 @@ public class Account {
this.reservedUsernameHash = reservedUsernameHash;
}
@Nullable
public UUID getUsernameLinkHandle() {
requireNotStale();
return usernameLinkHandle;
}
public Optional<byte[]> getEncryptedUsername() {
requireNotStale();
return Optional.ofNullable(encryptedUsername);
}
public void setUsernameLinkDetails(@Nullable final UUID usernameLinkHandle, @Nullable final byte[] encryptedUsername) {
requireNotStale();
if ((usernameLinkHandle == null) ^ (encryptedUsername == null)) {
throw new IllegalArgumentException("Both or neither arguments must be null");
}
if (usernameHash == null && encryptedUsername != null) {
throw new IllegalArgumentException("usernameHash field must be set to store username link");
}
this.encryptedUsername = encryptedUsername;
this.usernameLinkHandle = usernameLinkHandle;
}
/*
* This method is intentionally left package-private so that it's only used
* when Account is read from DB
*/
void setUsernameLinkHandle(@Nullable final UUID usernameLinkHandle) {
requireNotStale();
this.usernameLinkHandle = usernameLinkHandle;
}
public void addDevice(Device device) {
requireNotStale();

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
@ -45,6 +45,8 @@ import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
@ -68,6 +70,7 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber"));
private static final Timer GET_BY_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameHash"));
private static final Timer GET_BY_USERNAME_LINK_HANDLE_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameLinkHandle"));
private static final Timer GET_BY_PNI_TIMER = Metrics.timer(name(Accounts.class, "getByPni"));
private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(Accounts.class, "getByUuid"));
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
@ -82,6 +85,8 @@ public class Accounts extends AbstractDynamoDbStore {
static final String KEY_ACCOUNT_UUID = "U";
// uuid, attribute on account table, primary key for PNI table
static final String ATTR_PNI_UUID = "PNI";
// uuid of the current username link or null
static final String ATTR_USERNAME_LINK_UUID = "UL";
// phone number
static final String ATTR_ACCOUNT_E164 = "P";
// account, serialized to JSON
@ -99,6 +104,8 @@ public class Accounts extends AbstractDynamoDbStore {
// time to live; number
static final String ATTR_TTL = "TTL";
static final String USERNAME_LINK_TO_UUID_INDEX = "ul_to_u";
private final Clock clock;
private final DynamoDbAsyncClient asyncClient;
@ -525,20 +532,28 @@ public class Accounts extends AbstractDynamoDbStore {
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)));
final String updateExpression;
final StringBuilder updateExpressionBuilder = new StringBuilder("SET #data = :data, #cds = :cds");
if (account.getUnidentifiedAccessKey().isPresent()) {
// if it's present in the account, also set the uak
attrNames.put("#uak", ATTR_UAK);
attrValues.put(":uak", AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
updateExpression = "SET #data = :data, #cds = :cds, #uak = :uak ADD #version :version_increment";
} else {
updateExpression = "SET #data = :data, #cds = :cds ADD #version :version_increment";
updateExpressionBuilder.append(", #uak = :uak");
}
if (account.getEncryptedUsername().isPresent() && account.getUsernameLinkHandle() != null) {
attrNames.put("#ul", ATTR_USERNAME_LINK_UUID);
attrValues.put(":ul", AttributeValues.fromUUID(account.getUsernameLinkHandle()));
updateExpressionBuilder.append(", #ul = :ul");
}
updateExpressionBuilder.append(" ADD #version :version_increment");
if (account.getEncryptedUsername().isEmpty() || account.getUsernameLinkHandle() == null) {
attrNames.put("#ul", ATTR_USERNAME_LINK_UUID);
updateExpressionBuilder.append(" REMOVE #ul");
}
updateItemRequest = UpdateItemRequest.builder()
.tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression(updateExpression)
.updateExpression(updateExpressionBuilder.toString())
.conditionExpression("attribute_exists(#number) AND #version = :version")
.expressionAttributeNames(attrNames)
.expressionAttributeValues(attrValues)
@ -630,11 +645,18 @@ public class Accounts extends AbstractDynamoDbStore {
);
}
@Nonnull
public Optional<Account> getByUsernameLinkHandle(final UUID usernameLinkHandle) {
return requireNonNull(GET_BY_USERNAME_LINK_HANDLE_TIMER.record(() ->
itemByGsiKey(accountsTableName, USERNAME_LINK_TO_UUID_INDEX, ATTR_USERNAME_LINK_UUID, AttributeValues.fromUUID(usernameLinkHandle))
.map(Accounts::fromItem)));
}
@Nonnull
public Optional<Account> getByAccountIdentifier(final UUID uuid) {
return requireNonNull(GET_BY_UUID_TIMER.record(() ->
itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid))
.map(Accounts::fromItem)));
itemByKey(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid))
.map(Accounts::fromItem)));
}
public void delete(final UUID uuid) {
@ -706,6 +728,33 @@ public class Accounts extends AbstractDynamoDbStore {
return Optional.ofNullable(response.item()).filter(m -> !m.isEmpty());
}
@Nonnull
private Optional<Map<String, AttributeValue>> itemByGsiKey(final String table, final String indexName, final String keyName, final AttributeValue keyValue) {
final QueryResponse response = db().query(QueryRequest.builder()
.tableName(table)
.indexName(indexName)
.keyConditionExpression("#gsiKey = :gsiValue")
.projectionExpression("#uuid")
.expressionAttributeNames(Map.of(
"#gsiKey", keyName,
"#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(
":gsiValue", keyValue))
.build());
if (response.count() == 0) {
return Optional.empty();
}
if (response.count() > 1) {
throw new IllegalStateException("More than one row located for GSI [%s], key-value pair [%s, %s]"
.formatted(indexName, keyName, keyValue));
}
final AttributeValue primaryKeyValue = response.items().get(0).get(KEY_ACCOUNT_UUID);
return itemByKey(table, KEY_ACCOUNT_UUID, primaryKeyValue);
}
@Nonnull
private TransactWriteItem buildAccountPut(
final Account account,
@ -854,6 +903,7 @@ public class Accounts extends AbstractDynamoDbStore {
account.setNumber(item.get(ATTR_ACCOUNT_E164).s(), phoneNumberIdentifierFromAttribute);
account.setUuid(accountIdentifier);
account.setUsernameHash(AttributeValues.getByteArray(item, ATTR_USERNAME_HASH, null));
account.setUsernameLinkHandle(AttributeValues.getUUID(item, ATTR_USERNAME_LINK_UUID, null));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE))
.map(AttributeValue::bool)

View File

@ -36,7 +36,6 @@ import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.ObjectUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -63,12 +62,14 @@ public class AccountsManager {
private static final Timer updateTimer = metricRegistry.timer(name(AccountsManager.class, "update"));
private static final Timer getByNumberTimer = metricRegistry.timer(name(AccountsManager.class, "getByNumber"));
private static final Timer getByUsernameHashTimer = metricRegistry.timer(name(AccountsManager.class, "getByUsernameHash"));
private static final Timer getByUsernameLinkHandleTimer = metricRegistry.timer(name(AccountsManager.class, "getByUsernameLinkHandle"));
private static final Timer getByUuidTimer = metricRegistry.timer(name(AccountsManager.class, "getByUuid"));
private static final Timer deleteTimer = metricRegistry.timer(name(AccountsManager.class, "delete"));
private static final Timer redisSetTimer = metricRegistry.timer(name(AccountsManager.class, "redisSet"));
private static final Timer redisNumberGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisNumberGet"));
private static final Timer redisUsernameHashGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUsernameHashGet"));
private static final Timer redisUsernameLinkHandleGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUsernameLinkHandleGet"));
private static final Timer redisPniGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisPniGet"));
private static final Timer redisUuidGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUuidGet"));
private static final Timer redisDeleteTimer = metricRegistry.timer(name(AccountsManager.class, "redisDelete"));
@ -620,55 +621,44 @@ public class AccountsManager {
});
}
public Optional<Account> getByE164(String number) {
try (Timer.Context ignored = getByNumberTimer.time()) {
Optional<Account> account = redisGetByE164(number);
if (account.isEmpty()) {
account = accounts.getByE164(number);
account.ifPresent(this::redisSet);
}
return account;
}
public Optional<Account> getByE164(final String number) {
return checkRedisThenAccounts(
getByNumberTimer,
() -> redisGetBySecondaryKey(getAccountMapKey(number), redisNumberGetTimer),
() -> accounts.getByE164(number)
);
}
public Optional<Account> getByPhoneNumberIdentifier(UUID pni) {
try (Timer.Context ignored = getByNumberTimer.time()) {
Optional<Account> account = redisGetByPhoneNumberIdentifier(pni);
public Optional<Account> getByPhoneNumberIdentifier(final UUID pni) {
return checkRedisThenAccounts(
getByNumberTimer,
() -> redisGetBySecondaryKey(getAccountMapKey(pni.toString()), redisPniGetTimer),
() -> accounts.getByPhoneNumberIdentifier(pni)
);
}
if (account.isEmpty()) {
account = accounts.getByPhoneNumberIdentifier(pni);
account.ifPresent(this::redisSet);
}
return account;
}
public Optional<Account> getByUsernameLinkHandle(final UUID usernameLinkHandle) {
return checkRedisThenAccounts(
getByUsernameLinkHandleTimer,
() -> redisGetBySecondaryKey(getAccountMapKey(usernameLinkHandle.toString()), redisUsernameLinkHandleGetTimer),
() -> accounts.getByUsernameLinkHandle(usernameLinkHandle)
);
}
public Optional<Account> getByUsernameHash(final byte[] usernameHash) {
try (final Timer.Context ignored = getByUsernameHashTimer.time()) {
Optional<Account> account = redisGetByUsernameHash(usernameHash);
if (account.isEmpty()) {
account = accounts.getByUsernameHash(usernameHash);
account.ifPresent(this::redisSet);
}
return account;
}
return checkRedisThenAccounts(
getByUsernameHashTimer,
() -> redisGetBySecondaryKey(getUsernameHashAccountMapKey(usernameHash), redisUsernameHashGetTimer),
() -> accounts.getByUsernameHash(usernameHash)
);
}
public Optional<Account> getByAccountIdentifier(UUID uuid) {
try (Timer.Context ignored = getByUuidTimer.time()) {
Optional<Account> account = redisGetByAccountIdentifier(uuid);
if (account.isEmpty()) {
account = accounts.getByAccountIdentifier(uuid);
account.ifPresent(this::redisSet);
}
return account;
}
public Optional<Account> getByAccountIdentifier(final UUID uuid) {
return checkRedisThenAccounts(
getByUuidTimer,
() -> redisGetByAccountIdentifier(uuid),
() -> accounts.getByAccountIdentifier(uuid)
);
}
public UUID getPhoneNumberIdentifier(String e164) {
@ -758,24 +748,25 @@ public class AccountsManager {
}
}
private Optional<Account> redisGetByPhoneNumberIdentifier(UUID uuid) {
return redisGetBySecondaryKey(getAccountMapKey(uuid.toString()), redisPniGetTimer);
private Optional<Account> checkRedisThenAccounts(
final Timer overallTimer,
final Supplier<Optional<Account>> resolveFromRedis,
final Supplier<Optional<Account>> resolveFromAccounts) {
try (final Timer.Context ignored = overallTimer.time()) {
Optional<Account> account = resolveFromRedis.get();
if (account.isEmpty()) {
account = resolveFromAccounts.get();
account.ifPresent(this::redisSet);
}
return account;
}
}
private Optional<Account> redisGetByE164(String e164) {
return redisGetBySecondaryKey(getAccountMapKey(e164), redisNumberGetTimer);
}
private Optional<Account> redisGetByUsernameHash(byte[] usernameHash) {
return redisGetBySecondaryKey(getUsernameHashAccountMapKey(usernameHash), redisUsernameHashGetTimer);
}
private Optional<Account> redisGetBySecondaryKey(String secondaryKey, Timer timer) {
try (Timer.Context ignored = timer.time()) {
final String uuid = cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey));
if (uuid != null) return redisGetByAccountIdentifier(UUID.fromString(uuid));
else return Optional.empty();
private Optional<Account> redisGetBySecondaryKey(final String secondaryKey, final Timer timer) {
try (final Timer.Context ignored = timer.time()) {
return Optional.ofNullable(cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey)))
.map(UUID::fromString)
.flatMap(this::getByAccountIdentifier);
} catch (IllegalArgumentException e) {
logger.warn("Deserialization error", e);
return Optional.empty();

View File

@ -1,10 +1,15 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.signal.libsignal.usernames.BaseUsernameException;
import org.signal.libsignal.usernames.Username;
public class UsernameHashZkProofVerifier {
public void verifyProof(byte[] proof, byte[] hash) throws BaseUsernameException {
public void verifyProof(final byte[] proof, final byte[] hash) throws BaseUsernameException {
Username.verifyProof(proof, hash);
}
}

View File

@ -15,6 +15,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
@ -54,6 +55,7 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomUtils;
@ -93,6 +95,7 @@ import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
import org.whispersystems.textsecuregcm.entities.ConfirmUsernameHashRequest;
import org.whispersystems.textsecuregcm.entities.EncryptedUsername;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
@ -1081,7 +1084,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/1234")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(), MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(accountsManager).create(eq(SENDER), eq("bar"), any(), any(), anyList());
@ -1094,7 +1097,7 @@ class AccountControllerTest {
final Response response = resources.getJerseyTest()
.target("/v1/accounts/code/1234")
.request()
.header("Authorization", "This is not a valid authorization header")
.header(HttpHeaders.AUTHORIZATION, "This is not a valid authorization header")
.put(Entity.entity(new AccountAttributes(), MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(401);
@ -1106,7 +1109,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/1234")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_OLD, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_OLD, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1130,7 +1133,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/1111")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1155,7 +1158,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/666666")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(
new AccountAttributes(false, 3333, null, HexFormat.of().formatHex(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -1180,7 +1183,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/666666")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(
new AccountAttributes(false, 3333, null, HexFormat.of().formatHex(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -1214,7 +1217,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/666666")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -1241,7 +1244,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/666666")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null,
HexFormat.of().formatHex(new byte[32]), true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1268,7 +1271,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/666666")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1305,7 +1308,7 @@ class AccountControllerTest {
.target("/v1/accounts/code/1234")
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1329,7 +1332,7 @@ class AccountControllerTest {
.target("/v1/accounts/code/1234")
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1352,7 +1355,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/code/1234")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1375,7 +1378,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -1397,7 +1400,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1415,7 +1418,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1434,7 +1437,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(AuthHelper.VALID_NUMBER, "567890", null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
@ -1452,7 +1455,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1476,7 +1479,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1512,7 +1515,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1547,7 +1550,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, null, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1587,7 +1590,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1626,7 +1629,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(number, code, reglock, null, null, null, null, null),
MediaType.APPLICATION_JSON_TYPE));
@ -1675,7 +1678,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
number, code, null,
pniIdentityKey, deviceMessages,
@ -1729,7 +1732,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/number")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ChangePhoneNumberRequest(
AuthHelper.VALID_NUMBER, code, null,
pniIdentityKey, deviceMessages,
@ -1754,7 +1757,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/registration_lock/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234")));
assertThat(response.getStatus()).isEqualTo(204);
@ -1776,7 +1779,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/registration_lock/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new RegistrationLock("313")));
assertThat(response.getStatus()).isEqualTo(422);
@ -1788,7 +1791,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/registration_lock/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new RegistrationLock("1234567890123456789012345678901234567890123456789012345678901234")));
assertThat(response.getStatus()).isEqualTo(401);
@ -1800,7 +1803,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/gcm/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new GcmRegistrationId("z000")));
assertThat(response.getStatus()).isEqualTo(204);
@ -1815,7 +1818,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/gcm/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(Entity.json("{}"));
assertThat(response.getStatus()).isEqualTo(422);
@ -1828,7 +1831,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/apn/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new ApnRegistrationId("first", "second")));
assertThat(response.getStatus()).isEqualTo(204);
@ -1844,7 +1847,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/apn/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.put(Entity.json(new ApnRegistrationId("first", null)));
assertThat(response.getStatus()).isEqualTo(204);
@ -1870,7 +1873,7 @@ class AccountControllerTest {
final Response response = resources.getJerseyTest()
.target(path)
.request()
.header("Authorization", AuthHelper.getAuthHeader(aci, password))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(aci, password))
.get();
assertThat(response.getStatus()).isEqualTo(expectedHttpStatusCode);
@ -1889,6 +1892,145 @@ class AccountControllerTest {
);
}
static Stream<Arguments> testSetUsernameLink() {
return Stream.of(
Arguments.of(false, true, true, 32, 401),
Arguments.of(true, true, false, 32, 409),
Arguments.of(true, true, true, 129, 422),
Arguments.of(true, true, true, 0, 422),
Arguments.of(true, false, true, 32, 429),
Arguments.of(true, true, true, 128, 200)
);
}
@ParameterizedTest
@MethodSource
public void testSetUsernameLink(
final boolean auth,
final boolean passRateLimiting,
final boolean setUsernameHash,
final int payloadSize,
final int expectedStatus) throws Exception {
// checking if rate limiting needs to pass or fail for this test
if (passRateLimiting) {
MockUtils.updateRateLimiterResponseToAllow(
rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID);
} else {
MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID, Duration.ofMinutes(10), false);
}
// checking if username is to be set for this test
if (setUsernameHash) {
when(AuthHelper.VALID_ACCOUNT.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH_1));
} else {
when(AuthHelper.VALID_ACCOUNT.getUsernameHash()).thenReturn(Optional.empty());
}
final Invocation.Builder builder = resources.getJerseyTest()
.target("/v1/accounts/username_link")
.request();
// checking if auth is needed for this test
if (auth) {
builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
}
// make sure `update()` works
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any());
final Response put = builder.put(Entity.json(new EncryptedUsername(RandomUtils.nextBytes(payloadSize))));
assertEquals(expectedStatus, put.getStatus());
}
static Stream<Arguments> testDeleteUsernameLink() {
return Stream.of(
Arguments.of(false, true, 401),
Arguments.of(true, false, 429),
Arguments.of(true, true, 204)
);
}
@ParameterizedTest
@MethodSource
public void testDeleteUsernameLink(
final boolean auth,
final boolean passRateLimiting,
final int expectedStatus) throws Exception {
// checking if rate limiting needs to pass or fail for this test
if (passRateLimiting) {
MockUtils.updateRateLimiterResponseToAllow(
rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID);
} else {
MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.For.USERNAME_LINK_OPERATION, AuthHelper.VALID_UUID, Duration.ofMinutes(10), false);
}
final Invocation.Builder builder = resources.getJerseyTest()
.target("/v1/accounts/username_link")
.request();
// checking if auth is needed for this test
if (auth) {
builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
}
// make sure `update()` works
doReturn(AuthHelper.VALID_ACCOUNT).when(accountsManager).update(any(), any());
final Response delete = builder.delete();
assertEquals(expectedStatus, delete.getStatus());
}
static Stream<Arguments> testLookupUsernameLink() {
return Stream.of(
Arguments.of(false, true, true, true, 403),
Arguments.of(true, false, true, true, 429),
Arguments.of(true, true, false, true, 404),
Arguments.of(true, true, true, false, 404),
Arguments.of(true, true, true, true, 200)
);
}
@ParameterizedTest
@MethodSource
public void testLookupUsernameLink(
final boolean stayUnauthenticated,
final boolean passRateLimiting,
final boolean validUuidInput,
final boolean locateLinkByUuid,
final int expectedStatus) throws Exception {
MockUtils.updateRateLimiterResponseToAllow(
rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, NICE_HOST);
MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.For.USERNAME_LINK_LOOKUP_PER_IP, RATE_LIMITED_IP_HOST, Duration.ofMinutes(10), false);
final String uuid = validUuidInput ? UUID.randomUUID().toString() : "invalid-uuid";
if (validUuidInput && locateLinkByUuid) {
final Account account = mock(Account.class);
doReturn(Optional.of(RandomUtils.nextBytes(16))).when(account).getEncryptedUsername();
doReturn(Optional.of(account)).when(accountsManager).getByUsernameLinkHandle(eq(UUID.fromString(uuid)));
}
final Invocation.Builder builder = resources.getJerseyTest()
.target("/v1/accounts/username_link/" + uuid)
.request();
if (!stayUnauthenticated) {
builder.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
}
final Response get = builder
.header(HttpHeaders.X_FORWARDED_FOR, passRateLimiting ? NICE_HOST : RATE_LIMITED_IP_HOST)
.get();
assertEquals(expectedStatus, get.getStatus());
}
@Test
void testReserveUsernameHash() throws UsernameHashNotAvailableException {
when(accountsManager.reserveUsernameHash(any(), any()))
@ -1897,7 +2039,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ReserveUsernameHashRequest(List.of(USERNAME_HASH_1, USERNAME_HASH_2))));
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.readEntity(ReserveUsernameHashResponse.class))
@ -1912,7 +2054,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ReserveUsernameHashRequest(List.of(USERNAME_HASH_1, USERNAME_HASH_2))));
assertThat(response.getStatus()).isEqualTo(409);
}
@ -1924,7 +2066,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ReserveUsernameHashRequest(usernameHashes)));
assertThat(response.getStatus()).isEqualTo(422);
}
@ -1943,7 +2085,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ReserveUsernameHashRequest(usernameHashes)));
assertThat(response.getStatus()).isEqualTo(422);
}
@ -1954,7 +2096,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ReserveUsernameHashRequest(null)));
assertThat(response.getStatus()).isEqualTo(422);
}
@ -1965,7 +2107,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/reserve")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(
// Has '+' and '='characters which are invalid in base64url
"""
@ -1986,7 +2128,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF)));
assertThat(response.getStatus()).isEqualTo(200);
assertArrayEquals(response.readEntity(UsernameHashResponse.class).usernameHash(), USERNAME_HASH_1);
@ -2002,7 +2144,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF)));
assertThat(response.getStatus()).isEqualTo(409);
verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1);
@ -2017,7 +2159,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF)));
assertThat(response.getStatus()).isEqualTo(410);
verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1);
@ -2029,7 +2171,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(
// Has '+' and '='characters which are invalid in base64url
"""
@ -2049,7 +2191,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ConfirmUsernameHashRequest(usernameHash, ZK_PROOF)));
assertThat(response.getStatus()).isEqualTo(422);
verifyNoInteractions(usernameZkProofVerifier);
@ -2062,7 +2204,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/confirm")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new ConfirmUsernameHashRequest(USERNAME_HASH_1, ZK_PROOF)));
assertThat(response.getStatus()).isEqualTo(422);
verify(usernameZkProofVerifier).verifyProof(ZK_PROOF, USERNAME_HASH_1);
@ -2074,7 +2216,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.delete();
assertThat(response.getStatus()).isEqualTo(204);
@ -2087,7 +2229,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/username_hash/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.delete();
assertThat(response.getStatus()).isEqualTo(401);
@ -2099,7 +2241,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/attributes/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
@ -2111,7 +2253,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/attributes/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD))
.put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
@ -2124,7 +2266,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/attributes/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.UNDISCOVERABLE_UUID, AuthHelper.UNDISCOVERABLE_PASSWORD))
.put(Entity.json(new AccountAttributes(false, 2222, null, null, true, null)
.withRecoveryPassword(recoveryPassword)));
@ -2138,7 +2280,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/attributes/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new AccountAttributes(false, 2222, null, null, false, null)));
assertThat(response.getStatus()).isEqualTo(204);
@ -2150,7 +2292,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/attributes/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(new AccountAttributes(false, 2222, null, null, false, null)
.withUnidentifiedAccessKey(new byte[7])));
@ -2163,7 +2305,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/me")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.delete();
assertThat(response.getStatus()).isEqualTo(204);
@ -2178,7 +2320,7 @@ class AccountControllerTest {
resources.getJerseyTest()
.target("/v1/accounts/me")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.delete();
assertThat(response.getStatus()).isEqualTo(500);
@ -2300,7 +2442,7 @@ class AccountControllerTest {
assertThat(resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", UUID.randomUUID()))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.head()
.getStatus()).isEqualTo(400);
@ -2352,7 +2494,7 @@ class AccountControllerTest {
assertThat(resources.getJerseyTest()
.target(String.format("/v1/accounts/username_hash/%s", USERNAME_HASH_1))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1")
.get()
.getStatus()).isEqualTo(400);

View File

@ -12,6 +12,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -20,6 +21,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.uuid.UUIDComparator;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
@ -32,12 +34,21 @@ import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import org.apache.commons.lang3.RandomUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@ -68,6 +79,8 @@ class AccountsTest {
private static final int SCAN_PAGE_SIZE = 1;
private static final AtomicInteger ACCOUNT_COUNTER = new AtomicInteger(1);
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
@ -100,6 +113,88 @@ class AccountsTest {
SCAN_PAGE_SIZE);
}
@Test
public void testStoreAndLookupUsernameLink() throws Exception {
final Account account = nextRandomAccount();
account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account);
final BiConsumer<Optional<Account>, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> {
assertTrue(maybeAccount.isPresent());
assertTrue(maybeAccount.get().getEncryptedUsername().isPresent());
assertEquals(account.getUuid(), maybeAccount.get().getUuid());
assertArrayEquals(expectedEncryptedUsername, maybeAccount.get().getEncryptedUsername().get());
};
// creating a username link, storing it, checking that it can be looked up
final UUID linkHandle1 = UUID.randomUUID();
final byte[] encruptedUsername1 = RandomUtils.nextBytes(32);
account.setUsernameLinkDetails(linkHandle1, encruptedUsername1);
accounts.update(account);
validator.accept(accounts.getByUsernameLinkHandle(linkHandle1), encruptedUsername1);
// updating username link, storing new one, checking that it can be looked up, checking that old one can't be looked up
final UUID linkHandle2 = UUID.randomUUID();
final byte[] encruptedUsername2 = RandomUtils.nextBytes(32);
account.setUsernameLinkDetails(linkHandle2, encruptedUsername2);
accounts.update(account);
validator.accept(accounts.getByUsernameLinkHandle(linkHandle2), encruptedUsername2);
assertTrue(accounts.getByUsernameLinkHandle(linkHandle1).isEmpty());
// deleting username link, checking it can't be looked up by either handle
account.setUsernameLinkDetails(null, null);
accounts.update(account);
assertTrue(accounts.getByUsernameLinkHandle(linkHandle1).isEmpty());
assertTrue(accounts.getByUsernameLinkHandle(linkHandle2).isEmpty());
}
@Test
public void testUsernameLinksViaAccountsManager() throws Exception {
final AccountsManager accountsManager = new AccountsManager(
accounts,
mock(PhoneNumberIdentifiers.class),
mock(FaultTolerantRedisCluster.class),
mock(DeletedAccountsManager.class),
mock(Keys.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
mock(ExperimentEnrollmentManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(Clock.class));
final Account account = nextRandomAccount();
account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account);
final UUID linkHandle = UUID.randomUUID();
final byte[] encruptedUsername = RandomUtils.nextBytes(32);
accountsManager.update(account, a -> a.setUsernameLinkDetails(linkHandle, encruptedUsername));
final Optional<Account> maybeAccount = accountsManager.getByUsernameLinkHandle(linkHandle);
assertTrue(maybeAccount.isPresent());
assertTrue(maybeAccount.get().getEncryptedUsername().isPresent());
assertArrayEquals(encruptedUsername, maybeAccount.get().getEncryptedUsername().get());
// making some unrelated change and updating account to check that username link data is still there
final Optional<Account> accountToChange = accountsManager.getByAccountIdentifier(account.getUuid());
assertTrue(accountToChange.isPresent());
accountsManager.update(accountToChange.get(), a -> a.setDiscoverableByPhoneNumber(!a.isDiscoverableByPhoneNumber()));
final Optional<Account> accountAfterChange = accountsManager.getByUsernameLinkHandle(linkHandle);
assertTrue(accountAfterChange.isPresent());
assertTrue(accountAfterChange.get().getEncryptedUsername().isPresent());
assertArrayEquals(encruptedUsername, accountAfterChange.get().getEncryptedUsername().get());
// now deleting the link
final Optional<Account> accountToDeleteLink = accountsManager.getByAccountIdentifier(account.getUuid());
accountsManager.update(accountToDeleteLink.get(), a -> a.setUsernameLinkDetails(null, null));
assertTrue(accounts.getByUsernameLinkHandle(linkHandle).isEmpty());
}
@Test
void testStore() {
Device device = generateDevice(1);
@ -818,19 +913,24 @@ class AccountsTest {
assertThat(account.getUsernameHash()).isEmpty();
}
private Device generateDevice(long id) {
private static Device generateDevice(long id) {
return DevicesHelper.createDevice(id);
}
private Account generateAccount(String number, UUID uuid, final UUID pni) {
private static Account nextRandomAccount() {
final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement());
return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID());
}
private static Account generateAccount(String number, UUID uuid, final UUID pni) {
Device device = generateDevice(1);
return generateAccount(number, uuid, pni, List.of(device));
}
private Account generateAccount(String number, UUID uuid, final UUID pni, List<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
private static Account generateAccount(String number, UUID uuid, final UUID pni, List<Device> devices) {
final byte[] unidentifiedAccessKey = new byte[16];
final Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte) random.nextInt(255));
return AccountsHelper.generateTestAccount(number, uuid, pni, devices, unidentifiedAccessKey);
}

View File

@ -23,11 +23,29 @@ public final class DynamoDbExtensionSchema {
ACCOUNTS("accounts_test",
Accounts.KEY_ACCOUNT_UUID,
null,
List.of(AttributeDefinition.builder()
.attributeName(Accounts.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(), List.of()),
List.of(
AttributeDefinition.builder()
.attributeName(Accounts.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build(),
AttributeDefinition.builder()
.attributeName(Accounts.ATTR_USERNAME_LINK_UUID)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(
GlobalSecondaryIndex.builder()
.indexName(Accounts.USERNAME_LINK_TO_UUID_INDEX)
.keySchema(
KeySchemaElement.builder()
.attributeName(Accounts.ATTR_USERNAME_LINK_UUID)
.keyType(KeyType.HASH)
.build()
)
.projection(Projection.builder().projectionType(ProjectionType.KEYS_ONLY).build())
.provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(10L).writeCapacityUnits(10L).build())
.build()
),
List.of()),
DELETED_ACCOUNTS("deleted_accounts_test",
DeletedAccounts.KEY_ACCOUNT_E164,

View File

@ -114,6 +114,7 @@ public class AccountsHelper {
case "getPhoneNumberIdentifier" -> when(updatedAccount.getPhoneNumberIdentifier()).thenAnswer(stubbing);
case "getNumber" -> when(updatedAccount.getNumber()).thenAnswer(stubbing);
case "getUsername" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing);
case "getUsernameHash" -> when(updatedAccount.getUsernameHash()).thenAnswer(stubbing);
case "getDevices" -> when(updatedAccount.getDevices()).thenAnswer(stubbing);
case "getDevice" -> when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "getMasterDevice" -> when(updatedAccount.getMasterDevice()).thenAnswer(stubbing);

View File

@ -11,7 +11,7 @@ import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.lang3.RandomUtils;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
@ -50,7 +50,20 @@ public final class MockUtils {
final RateLimiters.For handle,
final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).forDescriptor(eq(handle));
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doNothing().when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doNothing().when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
@ -73,6 +86,21 @@ public final class MockUtils {
}
}
public static void updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.For handle,
final UUID input,
final Duration retryAfter,
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
throw new RuntimeException(e);
}
}
public static SecretBytes randomSecretBytes(final int size) {
return new SecretBytes(RandomUtils.nextBytes(size));
}