Moving Account serialization logic to storage-specific classes

This commit is contained in:
Sergey Skrobotov 2023-07-20 11:12:44 -07:00
parent f5c57e5741
commit cf92007f66
7 changed files with 154 additions and 69 deletions

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonFilter;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
@ -29,12 +30,12 @@ import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter; import org.whispersystems.textsecuregcm.util.IdentityKeyAdapter;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@JsonFilter("Account")
public class Account { public class Account {
@JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class); private static final Logger logger = LoggerFactory.getLogger(Account.class);
@JsonIgnore @JsonProperty
private UUID uuid; private UUID uuid;
@JsonProperty("pni") @JsonProperty("pni")
@ -55,7 +56,7 @@ public class Account {
@Nullable @Nullable
private byte[] reservedUsernameHash; private byte[] reservedUsernameHash;
@JsonIgnore @JsonProperty
@Nullable @Nullable
private UUID usernameLinkHandle; private UUID usernameLinkHandle;
@ -103,16 +104,13 @@ public class Account {
@JsonIgnore @JsonIgnore
private boolean stale; private boolean stale;
@JsonIgnore
private boolean canonicallyDiscoverable;
public UUID getUuid() { public UUID getUuid() {
// this is the one method that may be called on a stale account // this is the one method that may be called on a stale account
return uuid; return uuid;
} }
public void setUuid(UUID uuid) { public void setUuid(final UUID uuid) {
requireNotStale(); requireNotStale();
this.uuid = uuid; this.uuid = uuid;
@ -140,7 +138,7 @@ public class Account {
return number; return number;
} }
public void setNumber(String number, UUID phoneNumberIdentifier) { public void setNumber(final String number, final UUID phoneNumberIdentifier) {
requireNotStale(); requireNotStale();
this.number = number; this.number = number;
@ -203,14 +201,14 @@ public class Account {
this.usernameLinkHandle = usernameLinkHandle; this.usernameLinkHandle = usernameLinkHandle;
} }
public void addDevice(Device device) { public void addDevice(final Device device) {
requireNotStale(); requireNotStale();
removeDevice(device.getId()); removeDevice(device.getId());
this.devices.add(device); this.devices.add(device);
} }
public void removeDevice(long deviceId) { public void removeDevice(final long deviceId) {
requireNotStale(); requireNotStale();
this.devices.removeIf(device -> device.getId() == deviceId); this.devices.removeIf(device -> device.getId() == deviceId);
@ -228,7 +226,7 @@ public class Account {
return getDevice(Device.MASTER_ID); return getDevice(Device.MASTER_ID);
} }
public Optional<Device> getDevice(long deviceId) { public Optional<Device> getDevice(final long deviceId) {
requireNotStale(); requireNotStale();
return devices.stream().filter(device -> device.getId() == deviceId).findFirst(); return devices.stream().filter(device -> device.getId() == deviceId).findFirst();
@ -278,7 +276,7 @@ public class Account {
return allEnabledDevicesHaveCapability(DeviceCapabilities::isPaymentActivation); return allEnabledDevicesHaveCapability(DeviceCapabilities::isPaymentActivation);
} }
private boolean allEnabledDevicesHaveCapability(Predicate<DeviceCapabilities> predicate) { private boolean allEnabledDevicesHaveCapability(final Predicate<DeviceCapabilities> predicate) {
requireNotStale(); requireNotStale();
return devices.stream() return devices.stream()
@ -309,25 +307,13 @@ public class Account {
int count = 0; int count = 0;
for (Device device : devices) { for (final Device device : devices) {
if (device.isEnabled()) count++; if (device.isEnabled()) count++;
} }
return count; return count;
} }
public boolean isCanonicallyDiscoverable() {
requireNotStale();
return canonicallyDiscoverable;
}
public void setCanonicallyDiscoverable(boolean canonicallyDiscoverable) {
requireNotStale();
this.canonicallyDiscoverable = canonicallyDiscoverable;
}
public void setIdentityKey(final IdentityKey identityKey) { public void setIdentityKey(final IdentityKey identityKey) {
requireNotStale(); requireNotStale();
@ -362,7 +348,7 @@ public class Account {
return Optional.ofNullable(currentProfileVersion); return Optional.ofNullable(currentProfileVersion);
} }
public void setCurrentProfileVersion(String currentProfileVersion) { public void setCurrentProfileVersion(final String currentProfileVersion) {
requireNotStale(); requireNotStale();
this.currentProfileVersion = currentProfileVersion; this.currentProfileVersion = currentProfileVersion;
@ -374,7 +360,7 @@ public class Account {
return badges; return badges;
} }
public void setBadges(Clock clock, List<AccountBadge> badges) { public void setBadges(final Clock clock, final List<AccountBadge> badges) {
requireNotStale(); requireNotStale();
this.badges = badges; this.badges = badges;
@ -382,11 +368,11 @@ public class Account {
purgeStaleBadges(clock); purgeStaleBadges(clock);
} }
public void addBadge(Clock clock, AccountBadge badge) { public void addBadge(final Clock clock, final AccountBadge badge) {
requireNotStale(); requireNotStale();
boolean added = false; boolean added = false;
for (int i = 0; i < badges.size(); i++) { for (int i = 0; i < badges.size(); i++) {
AccountBadge badgeInList = badges.get(i); final AccountBadge badgeInList = badges.get(i);
if (Objects.equals(badgeInList.getId(), badge.getId())) { if (Objects.equals(badgeInList.getId(), badge.getId())) {
if (added) { if (added) {
badges.remove(i); badges.remove(i);
@ -405,7 +391,7 @@ public class Account {
purgeStaleBadges(clock); purgeStaleBadges(clock);
} }
public void makeBadgePrimaryIfExists(Clock clock, String badgeId) { public void makeBadgePrimaryIfExists(final Clock clock, final String badgeId) {
requireNotStale(); requireNotStale();
// early exit if it's already the first item in the list // early exit if it's already the first item in the list
@ -429,28 +415,28 @@ public class Account {
purgeStaleBadges(clock); purgeStaleBadges(clock);
} }
public void removeBadge(Clock clock, String id) { public void removeBadge(final Clock clock, final String id) {
requireNotStale(); requireNotStale();
badges.removeIf(accountBadge -> Objects.equals(accountBadge.getId(), id)); badges.removeIf(accountBadge -> Objects.equals(accountBadge.getId(), id));
purgeStaleBadges(clock); purgeStaleBadges(clock);
} }
private void purgeStaleBadges(Clock clock) { private void purgeStaleBadges(final Clock clock) {
final Instant now = clock.instant(); final Instant now = clock.instant();
badges.removeIf(accountBadge -> now.isAfter(accountBadge.getExpiration())); badges.removeIf(accountBadge -> now.isAfter(accountBadge.getExpiration()));
} }
public void setRegistrationLockFromAttributes(final AccountAttributes attributes) { public void setRegistrationLockFromAttributes(final AccountAttributes attributes) {
if (!Util.isEmpty(attributes.getRegistrationLock())) { if (!Util.isEmpty(attributes.getRegistrationLock())) {
SaltedTokenHash credentials = SaltedTokenHash.generateFor(attributes.getRegistrationLock()); final SaltedTokenHash credentials = SaltedTokenHash.generateFor(attributes.getRegistrationLock());
setRegistrationLock(credentials.hash(), credentials.salt()); setRegistrationLock(credentials.hash(), credentials.salt());
} else { } else {
setRegistrationLock(null, null); setRegistrationLock(null, null);
} }
} }
public void setRegistrationLock(String registrationLock, String registrationLockSalt) { public void setRegistrationLock(final String registrationLock, final String registrationLockSalt) {
requireNotStale(); requireNotStale();
this.registrationLock = registrationLock; this.registrationLock = registrationLock;
@ -469,7 +455,7 @@ public class Account {
return Optional.ofNullable(unidentifiedAccessKey); return Optional.ofNullable(unidentifiedAccessKey);
} }
public void setUnidentifiedAccessKey(byte[] unidentifiedAccessKey) { public void setUnidentifiedAccessKey(final byte[] unidentifiedAccessKey) {
requireNotStale(); requireNotStale();
this.unidentifiedAccessKey = unidentifiedAccessKey; this.unidentifiedAccessKey = unidentifiedAccessKey;
@ -481,7 +467,7 @@ public class Account {
return unrestrictedUnidentifiedAccess; return unrestrictedUnidentifiedAccess;
} }
public void setUnrestrictedUnidentifiedAccess(boolean unrestrictedUnidentifiedAccess) { public void setUnrestrictedUnidentifiedAccess(final boolean unrestrictedUnidentifiedAccess) {
requireNotStale(); requireNotStale();
this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess; this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess;
@ -511,7 +497,7 @@ public class Account {
return version; return version;
} }
public void setVersion(int version) { public void setVersion(final int version) {
requireNotStale(); requireNotStale();
this.version = version; this.version = version;

View File

@ -8,6 +8,7 @@ import static com.codahale.metrics.MetricRegistry.name;
import static java.util.Objects.requireNonNull; import static java.util.Objects.requireNonNull;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables; import com.google.common.base.Throwables;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
@ -31,7 +32,6 @@ import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
@ -63,11 +63,25 @@ import software.amazon.awssdk.services.dynamodb.model.Update;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.CompletableFutureUtils;
/**
* "Accounts" DDB table's structure doesn't match 1:1 the {@link Account} class: most of the class fields are serialized
* and stored in the {@link Accounts#ATTR_ACCOUNT_DATA} attribute, however there are certain fields that are stored only as DDB attributes
* (e.g. if indexing or lookup by field is required), and there are also fields that stored in both places.
* This class contains all the logic that decides whether or not a field of the {@link Account} class should be
* added as an attribute, serialized as a part of {@link Accounts#ATTR_ACCOUNT_DATA}, or both. To skip serialization,
* make sure attribute name is listed in {@link Accounts#ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION}. If serialization is skipped,
* make sure the field is stored in a DDB attribute and then put back into the account object in {@link Accounts#fromItem(Map)}.
*/
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class Accounts extends AbstractDynamoDbStore { public class Accounts extends AbstractDynamoDbStore {
private static final Logger log = LoggerFactory.getLogger(Accounts.class); private static final Logger log = LoggerFactory.getLogger(Accounts.class);
static final List<String> ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION = List.of("uuid", "usernameLinkHandle");
private static final ObjectWriter ACCOUNT_DDB_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION));
private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create")); private static final Timer CREATE_TIMER = Metrics.timer(name(Accounts.class, "create"));
private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber")); private static final Timer CHANGE_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "changeNumber"));
private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername")); private static final Timer SET_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "setUsername"));
@ -207,7 +221,7 @@ public class Accounts extends AbstractDynamoDbStore {
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow(); final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
// It's up to the client to delete this username hash if they can't retrieve and decrypt the plaintext username from storage service // It's up to the client to delete this username hash if they can't retrieve and decrypt the plaintext username from storage service
existingAccount.getUsernameHash().ifPresent(existingUsernameHash -> account.setUsernameHash(existingUsernameHash)); existingAccount.getUsernameHash().ifPresent(account::setUsernameHash);
account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier()); account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier());
account.setVersion(existingAccount.getVersion()); account.setVersion(existingAccount.getVersion());
@ -281,7 +295,7 @@ public class Accounts extends AbstractDynamoDbStore {
"#version", ATTR_VERSION)) "#version", ATTR_VERSION))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":number", numberAttr, ":number", numberAttr,
":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":pni", pniAttr, ":pni", pniAttr,
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
@ -324,7 +338,7 @@ public class Accounts extends AbstractDynamoDbStore {
final long expirationTime = clock.instant().plus(ttl).getEpochSecond(); final long expirationTime = clock.instant().plus(ttl).getEpochSecond();
// Use account UUID as a "reservation token" - by providing this, the client proves ownership of the hash // Use account UUID as a "reservation token" - by providing this, the client proves ownership of the hash
UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
try { try {
final List<TransactWriteItem> writeItems = new ArrayList<>(); final List<TransactWriteItem> writeItems = new ArrayList<>();
@ -352,7 +366,7 @@ public class Accounts extends AbstractDynamoDbStore {
.conditionExpression("#version = :version") .conditionExpression("#version = :version")
.expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA, "#version", ATTR_VERSION)) .expressionAttributeNames(Map.of("#data", ATTR_ACCOUNT_DATA, "#version", ATTR_VERSION))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ":data", accountDataAttributeValue(account),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))) ":version_increment", AttributeValues.fromInt(1)))
.build()) .build())
@ -427,7 +441,7 @@ public class Accounts extends AbstractDynamoDbStore {
final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash"); final StringBuilder updateExpr = new StringBuilder("SET #data = :data, #username_hash = :username_hash");
final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of( final Map<String, AttributeValue> expressionAttributeValues = new HashMap<>(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ":data", accountDataAttributeValue(account),
":username_hash", AttributeValues.fromByteArray(usernameHash), ":username_hash", AttributeValues.fromByteArray(usernameHash),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))); ":version_increment", AttributeValues.fromInt(1)));
@ -503,7 +517,7 @@ public class Accounts extends AbstractDynamoDbStore {
"#username_hash", ATTR_USERNAME_HASH, "#username_hash", ATTR_USERNAME_HASH,
"#version", ATTR_VERSION)) "#version", ATTR_VERSION))
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ":data", accountDataAttributeValue(account),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))) ":version_increment", AttributeValues.fromInt(1)))
.build()) .build())
@ -547,8 +561,9 @@ public class Accounts extends AbstractDynamoDbStore {
"#data", ATTR_ACCOUNT_DATA, "#data", ATTR_ACCOUNT_DATA,
"#cds", ATTR_CANONICALLY_DISCOVERABLE, "#cds", ATTR_CANONICALLY_DISCOVERABLE,
"#version", ATTR_VERSION)); "#version", ATTR_VERSION));
final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of( final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))); ":version_increment", AttributeValues.fromInt(1)));
@ -861,7 +876,7 @@ public class Accounts extends AbstractDynamoDbStore {
KEY_ACCOUNT_UUID, uuidAttr, KEY_ACCOUNT_UUID, uuidAttr,
ATTR_ACCOUNT_E164, numberAttr, ATTR_ACCOUNT_E164, numberAttr,
ATTR_PNI_UUID, pniUuidAttr, ATTR_PNI_UUID, pniUuidAttr,
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), ATTR_ACCOUNT_DATA, accountDataAttributeValue(account),
ATTR_VERSION, AttributeValues.fromInt(account.getVersion()), ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory()))); ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory())));
@ -970,10 +985,10 @@ public class Accounts extends AbstractDynamoDbStore {
@VisibleForTesting @VisibleForTesting
@Nonnull @Nonnull
static Account fromItem(final Map<String, AttributeValue> item) { static Account fromItem(final Map<String, AttributeValue> item) {
// TODO: eventually require ATTR_CANONICALLY_DISCOVERABLE
if (!item.containsKey(ATTR_ACCOUNT_DATA) if (!item.containsKey(ATTR_ACCOUNT_DATA)
|| !item.containsKey(ATTR_ACCOUNT_E164) || !item.containsKey(ATTR_ACCOUNT_E164)
|| !item.containsKey(KEY_ACCOUNT_UUID)) { || !item.containsKey(KEY_ACCOUNT_UUID)
|| !item.containsKey(ATTR_CANONICALLY_DISCOVERABLE)) {
throw new RuntimeException("item missing values"); throw new RuntimeException("item missing values");
} }
try { try {
@ -994,9 +1009,6 @@ public class Accounts extends AbstractDynamoDbStore {
account.setUsernameHash(AttributeValues.getByteArray(item, ATTR_USERNAME_HASH, null)); account.setUsernameHash(AttributeValues.getByteArray(item, ATTR_USERNAME_HASH, null));
account.setUsernameLinkHandle(AttributeValues.getUUID(item, ATTR_USERNAME_LINK_UUID, null)); account.setUsernameLinkHandle(AttributeValues.getUUID(item, ATTR_USERNAME_LINK_UUID, null));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n())); account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
account.setCanonicallyDiscoverable(Optional.ofNullable(item.get(ATTR_CANONICALLY_DISCOVERABLE))
.map(AttributeValue::bool)
.orElse(false));
return account; return account;
@ -1005,6 +1017,10 @@ public class Accounts extends AbstractDynamoDbStore {
} }
} }
private static AttributeValue accountDataAttributeValue(final Account account) throws JsonProcessingException {
return AttributeValues.fromByteArray(ACCOUNT_DDB_JSON_WRITER.writeValueAsBytes(account));
}
private static boolean conditionalCheckFailed(final CancellationReason reason) { private static boolean conditionalCheckFailed(final CancellationReason reason) {
return CONDITIONAL_CHECK_FAILED.equals(reason.code()); return CONDITIONAL_CHECK_FAILED.equals(reason.code());
} }

View File

@ -12,7 +12,7 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
@ -112,7 +112,8 @@ public class AccountsManager {
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final Clock clock; private final Clock clock;
private static final ObjectMapper mapper = SystemMapper.jsonMapper(); private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
// An account that's used at least daily will get reset in the cache at least once per day when its "last seen" // An account that's used at least daily will get reset in the cache at least once per day when its "last seen"
// timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read // timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read
@ -454,7 +455,7 @@ public class AccountsManager {
/** /**
* Reserve a username hash so that no other accounts may take it. * Reserve a username hash so that no other accounts may take it.
* <p> * <p>
* The reserved hash can later be set with {@link #confirmReservedUsernameHash(Account, byte[])}. The reservation * The reserved hash can later be set with {@link #confirmReservedUsernameHash(Account, byte[], byte[])}. The reservation
* will eventually expire, after which point confirmReservedUsernameHash may fail if another account has taken the * will eventually expire, after which point confirmReservedUsernameHash may fail if another account has taken the
* username hash. * username hash.
* *
@ -657,7 +658,7 @@ public class AccountsManager {
final Supplier<Account> retriever, final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException { final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
Account originalAccount = cloneAccount(account); Account originalAccount = cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return account; return account;
@ -671,7 +672,7 @@ public class AccountsManager {
try { try {
persister.persistAccount(account); persister.persistAccount(account);
final Account updatedAccount = cloneAccount(account); final Account updatedAccount = cloneAccountAsNotStale(account);
account.markStale(); account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount); changeValidator.validateChange(originalAccount, updatedAccount);
@ -681,7 +682,7 @@ public class AccountsManager {
tries++; tries++;
account = retriever.get(); account = retriever.get();
originalAccount = cloneAccount(account); originalAccount = cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return account; return account;
@ -699,7 +700,7 @@ public class AccountsManager {
final AccountChangeValidator changeValidator, final AccountChangeValidator changeValidator,
final int remainingTries) { final int remainingTries) {
final Account originalAccount = cloneAccount(account); final Account originalAccount = cloneAccountAsNotStale(account);
if (!updater.apply(account)) { if (!updater.apply(account)) {
return CompletableFuture.completedFuture(account); return CompletableFuture.completedFuture(account);
@ -708,7 +709,7 @@ public class AccountsManager {
if (remainingTries > 0) { if (remainingTries > 0) {
return persister.apply(account) return persister.apply(account)
.thenApply(ignored -> { .thenApply(ignored -> {
final Account updatedAccount = cloneAccount(account); final Account updatedAccount = cloneAccountAsNotStale(account);
account.markStale(); account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount); changeValidator.validateChange(originalAccount, updatedAccount);
@ -728,13 +729,10 @@ public class AccountsManager {
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
} }
private static Account cloneAccount(final Account account) { private static Account cloneAccountAsNotStale(final Account account) {
try { try {
final Account clone = mapper.readValue(mapper.writeValueAsBytes(account), Account.class); return SystemMapper.jsonMapper().readValue(
clone.setUuid(account.getUuid()); SystemMapper.jsonMapper().writeValueAsBytes(account), Account.class);
clone.setUsernameLinkHandle(account.getUsernameLinkHandle());
return clone;
} catch (final IOException e) { } catch (final IOException e) {
// this should really, truly, never happen // this should really, truly, never happen
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
@ -901,7 +899,7 @@ public class AccountsManager {
private void redisSet(Account account) { private void redisSet(Account account) {
try (Timer.Context ignored = redisSetTimer.time()) { try (Timer.Context ignored = redisSetTimer.time()) {
final String accountJson = mapper.writeValueAsString(account); final String accountJson = ACCOUNT_REDIS_JSON_WRITER.writeValueAsString(account);
cacheCluster.useCluster(connection -> { cacheCluster.useCluster(connection -> {
final RedisAdvancedClusterCommands<String, String> commands = connection.sync(); final RedisAdvancedClusterCommands<String, String> commands = connection.sync();
@ -922,7 +920,7 @@ public class AccountsManager {
final String accountJson; final String accountJson;
try { try {
accountJson = mapper.writeValueAsString(account); accountJson = ACCOUNT_REDIS_JSON_WRITER.writeValueAsString(account);
} catch (final JsonProcessingException e) { } catch (final JsonProcessingException e) {
throw new UncheckedIOException(e); throw new UncheckedIOException(e);
} }
@ -1036,7 +1034,7 @@ public class AccountsManager {
private static Optional<Account> parseAccountJson(@Nullable final String accountJson, final UUID uuid) { private static Optional<Account> parseAccountJson(@Nullable final String accountJson, final UUID uuid) {
try { try {
if (StringUtils.isNotBlank(accountJson)) { if (StringUtils.isNotBlank(accountJson)) {
Account account = mapper.readValue(accountJson, Account.class); Account account = SystemMapper.jsonMapper().readValue(accountJson, Account.class);
account.setUuid(uuid); account.setUuid(uuid);
if (account.getPhoneNumberIdentifier() == null) { if (account.getPhoneNumberIdentifier() == null) {

View File

@ -6,12 +6,19 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonFilter;
import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ser.FilterProvider;
import com.fasterxml.jackson.databind.ser.impl.SimpleBeanPropertyFilter;
import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider;
import com.fasterxml.jackson.dataformat.yaml.YAMLMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretsModule; import org.whispersystems.textsecuregcm.configuration.secrets.SecretsModule;
@ -34,6 +41,7 @@ public class SystemMapper {
public static ObjectMapper configureMapper(final ObjectMapper mapper) { public static ObjectMapper configureMapper(final ObjectMapper mapper) {
return mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) return mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.setFilterProvider(new SimpleFilterProvider().setDefaultFilter(SimpleBeanPropertyFilter.serializeAll()))
.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE) .setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE)
.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY) .setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY)
.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.PUBLIC_ONLY) .setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.PUBLIC_ONLY)
@ -42,4 +50,23 @@ public class SystemMapper {
new JavaTimeModule(), new JavaTimeModule(),
new Jdk8Module()); new Jdk8Module());
} }
public static FilterProvider excludingField(final Class<?> clazz, final List<String> fieldsToExclude) {
final String filterId = clazz.getSimpleName();
// validate that the target class is annotated with @JsonFilter,
final List<JsonFilter> jsonFilterAnnotations = Arrays.stream(clazz.getAnnotations())
.map(a -> a instanceof JsonFilter jsonFilter ? jsonFilter : null)
.filter(Objects::nonNull)
.toList();
if (jsonFilterAnnotations.size() != 1 || !jsonFilterAnnotations.get(0).value().equals(filterId)) {
throw new IllegalStateException("""
Class `%1$s` must have a single annotation of type `JsonFilter`
with the value equal to the name of the class itself: `@JsonFilter("%1$s")`
""".formatted(filterId));
}
return new SimpleFilterProvider()
.addFilter(filterId, SimpleBeanPropertyFilter.serializeAllExcept(fieldsToExclude.toArray(new String[0])));
}
} }

View File

@ -16,11 +16,15 @@ import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.createDevice; import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.createDevice;
import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.setEnabled; import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.setEnabled;
import com.fasterxml.jackson.annotation.JsonFilter;
import java.lang.annotation.Annotation;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -428,4 +432,16 @@ class AccountTest {
assertThat(badge.isVisible()).isTrue(); assertThat(badge.isVisible()).isTrue();
}); });
} }
@Test
public void testAccountClassJsonFilterIdMatchesClassName() throws Exception {
// Some logic relies on the @JsonFilter name being equal to the class name.
// This test is just making sure that annotation is there and that the ID matches class name.
final Optional<Annotation> maybeJsonFilterAnnotation = Arrays.stream(Account.class.getAnnotations())
.filter(a -> a.annotationType().equals(JsonFilter.class))
.findFirst();
assertTrue(maybeJsonFilterAnnotation.isPresent());
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
}
} }

View File

@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -989,6 +990,24 @@ class AccountsTest {
assertThat(account.getUsernameHash()).isEmpty(); assertThat(account.getUsernameHash()).isEmpty();
} }
@Test
public void testIgnoredFieldsNotAddedToDataAttribute() throws Exception {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
account.setUsernameHash(RandomUtils.nextBytes(32));
account.setUsernameLinkDetails(UUID.randomUUID(), RandomUtils.nextBytes(32));
accounts.create(account);
final Map<String, AttributeValue> accountRecord = DYNAMO_DB_EXTENSION.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.build())
.item();
final Map<?, ?> dataMap = SystemMapper.jsonMapper()
.readValue(accountRecord.get(Accounts.ATTR_ACCOUNT_DATA).b().asByteArray(), Map.class);
Accounts.ACCOUNT_FIELDS_TO_EXCLUDE_FROM_SERIALIZATION
.forEach(field -> assertFalse(dataMap.containsKey(field)));
}
private static Device generateDevice(long id) { private static Device generateDevice(long id) {
return DevicesHelper.createDevice(id); return DevicesHelper.createDevice(id);
} }

View File

@ -6,13 +6,19 @@
package org.whispersystems.textsecuregcm.util; package org.whispersystems.textsecuregcm.util;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import com.fasterxml.jackson.annotation.JsonFilter;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import java.util.List;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
@ -68,7 +74,6 @@ class SystemMapperTest {
} }
} }
@ParameterizedTest @ParameterizedTest
@ValueSource(classes = {DataClass.class, DataRecord.class}) @ValueSource(classes = {DataClass.class, DataRecord.class})
public void testOptionalField(final Class<? extends Data> clazz) throws Exception { public void testOptionalField(final Class<? extends Data> clazz) throws Exception {
@ -96,4 +101,22 @@ class SystemMapperTest {
Arguments.of(new DataRecord(null), JSON_NO_FIELD) Arguments.of(new DataRecord(null), JSON_NO_FIELD)
); );
} }
public record NotAnnotatedWithJsonFilter(String data) {
}
@JsonFilter("AnnotatedWithJsonFilter")
public record AnnotatedWithJsonFilter(String data, String excluded) {
}
@Test
public void testFiltering() throws Exception {
assertThrows(IllegalStateException.class, () -> SystemMapper.excludingField(NotAnnotatedWithJsonFilter.class, List.of("data")));
final ObjectWriter writer = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(AnnotatedWithJsonFilter.class, List.of("excluded")));
final AnnotatedWithJsonFilter obj = new AnnotatedWithJsonFilter("valData", "valExcluded");
final String json = writer.writeValueAsString(obj);
final Map<?, ?> serializedFields = SystemMapper.jsonMapper().readValue(json, Map.class);
assertEquals(Map.of("data", "valData"), serializedFields);
}
} }