Add optimistic locking to account updates

This commit is contained in:
Chris Eager 2021-07-07 11:54:22 -05:00 committed by Jon Chambers
parent 62022c7de1
commit 158d65c6a7
30 changed files with 1397 additions and 399 deletions

View File

@ -9,7 +9,6 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.basic.BasicCredentials;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.time.Clock;
@ -118,8 +117,7 @@ public class BaseAccountAuthenticator {
Metrics.summary(DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME, IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isMaster()))
.record(Duration.ofMillis(todayInMillisWithOffset - device.getLastSeen()).toDays());
device.setLastSeen(Util.todayInMillis(clock));
accountsManager.update(account);
accountsManager.updateDevice(account, device.getId(), d -> d.setLastSeen(Util.todayInMillis(clock)));
}
}

View File

@ -7,15 +7,15 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import java.time.Duration;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
public class CircuitBreakerConfiguration {
@JsonProperty
@ -39,6 +39,9 @@ public class CircuitBreakerConfiguration {
@Min(1)
private long waitDurationInOpenStateInSeconds = 10;
@JsonProperty
private List<String> ignoredExceptions = Collections.emptyList();
public int getFailureRateThreshold() {
return failureRateThreshold;
@ -56,6 +59,18 @@ public class CircuitBreakerConfiguration {
return waitDurationInOpenStateInSeconds;
}
public List<Class> getIgnoredExceptions() {
return ignoredExceptions.stream()
.map(name -> {
try {
return Class.forName(name);
} catch (final ClassNotFoundException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
}
@VisibleForTesting
public void setFailureRateThreshold(int failureRateThreshold) {
this.failureRateThreshold = failureRateThreshold;
@ -76,9 +91,15 @@ public class CircuitBreakerConfiguration {
this.waitDurationInOpenStateInSeconds = seconds;
}
@VisibleForTesting
public void setIgnoredExceptions(final List<String> ignoredExceptions) {
this.ignoredExceptions = ignoredExceptions;
}
public CircuitBreakerConfig toCircuitBreakerConfig() {
return CircuitBreakerConfig.custom()
.failureRateThreshold(getFailureRateThreshold())
.ignoreExceptions(getIgnoredExceptions().toArray(new Class[0]))
.ringBufferSizeInHalfOpenState(getRingBufferSizeInHalfOpenState())
.waitDurationInOpenState(Duration.ofSeconds(getWaitDurationInOpenStateInSeconds()))
.ringBufferSizeInClosedState(getRingBufferSizeInClosedState())

View File

@ -439,12 +439,12 @@ public class AccountController {
return;
}
device.setApnId(null);
device.setVoipApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
device.setFetchesMessages(false);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setVoipApnId(null);
d.setGcmId(registrationId.getGcmRegistrationId());
d.setFetchesMessages(false);
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@ -457,11 +457,12 @@ public class AccountController {
public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setGcmId(null);
device.setFetchesMessages(false);
device.setUserAgent("OWA");
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
d.setFetchesMessages(false);
d.setUserAgent("OWA");
});
directoryQueue.refreshRegisteredUser(account);
}
@ -474,11 +475,12 @@ public class AccountController {
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
device.setApnId(registrationId.getApnRegistrationId());
device.setVoipApnId(registrationId.getVoipRegistrationId());
device.setGcmId(null);
device.setFetchesMessages(false);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(registrationId.getApnRegistrationId());
d.setVoipApnId(registrationId.getVoipRegistrationId());
d.setGcmId(null);
d.setFetchesMessages(false);
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@ -491,15 +493,16 @@ public class AccountController {
public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setFetchesMessages(false);
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
accounts.update(account);
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setFetchesMessages(false);
if (d.getId() == 1) {
d.setUserAgent("OWI");
} else {
d.setUserAgent("OWP");
}
});
directoryQueue.refreshRegisteredUser(account);
}
@ -509,18 +512,18 @@ public class AccountController {
@Path("/registration_lock")
public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) {
AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock());
account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
account.setPin(null);
accounts.update(account);
accounts.update(account, a -> {
a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
a.setPin(null);
});
}
@Timed
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Auth Account account) {
account.setRegistrationLock(null, null);
accounts.update(account);
accounts.update(account, a -> a.setRegistrationLock(null, null));
}
@Timed
@ -531,21 +534,21 @@ public class AccountController {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN set by User-Agent: {}", userAgent);
account.setPin(accountLock.getPin());
account.setRegistrationLock(null, null);
accounts.update(account);
accounts.update(account, a -> {
a.setPin(accountLock.getPin());
a.setRegistrationLock(null, null);
});
}
@Timed
@DELETE
@Path("/pin/")
public void removePin(@Auth Account account, @HeaderParam("User-Agent") String userAgent) {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN removed by User-Agent: {}", userAgent);
account.setPin(null);
accounts.update(account);
accounts.update(account, a -> a.setPin(null));
}
@Timed
@ -553,8 +556,8 @@ public class AccountController {
@Path("/name/")
public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) {
Account account = disabledPermittedAccount.getAccount();
account.getAuthenticatedDevice().get().setName(deviceName.getDeviceName());
accounts.update(account);
Device device = account.getAuthenticatedDevice().get();
accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName()));
}
@Timed
@ -572,25 +575,29 @@ public class AccountController {
@Valid AccountAttributes attributes)
{
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
long deviceId = account.getAuthenticatedDevice().get().getId();
device.setFetchesMessages(attributes.getFetchesMessages());
device.setName(attributes.getName());
device.setLastSeen(Util.todayInMillis());
device.setCapabilities(attributes.getCapabilities());
device.setRegistrationId(attributes.getRegistrationId());
device.setUserAgent(userAgent);
account = accounts.update(account, a-> {
setAccountRegistrationLockFromAttributes(account, attributes);
a.getDevice(deviceId).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent);
});
setAccountRegistrationLockFromAttributes(a, attributes);
a.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
a.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
a.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
});
final boolean hasDiscoverabilityChange = (account.isDiscoverableByPhoneNumber() != attributes.isDiscoverableByPhoneNumber());
account.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
accounts.update(account);
if (hasDiscoverabilityChange) {
directoryQueue.refreshRegisteredUser(account);
}

View File

@ -100,8 +100,7 @@ public class DeviceController {
}
messages.clear(account.getUuid(), deviceId);
account.removeDevice(deviceId);
accounts.update(account);
account = accounts.update(account, a -> a.removeDevice(deviceId));
directoryQueue.refreshRegisteredUser(account);
// ensure any messages that came in after the first clear() are also removed
messages.clear(account.getUuid(), deviceId);
@ -192,15 +191,16 @@ public class DeviceController {
device.setName(accountAttributes.getName());
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
account.get().addDevice(device);
messages.clear(account.get().getUuid(), device.getId());
accounts.update(account.get());
accounts.update(account.get(), a -> {
device.setId(account.get().getNextDeviceId());
messages.clear(account.get().getUuid(), device.getId());
a.addDevice(device);
});;
pendingDevices.remove(number);
@ -224,8 +224,8 @@ public class DeviceController {
@Path("/capabilities")
public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) {
assert(account.getAuthenticatedDevice().isPresent());
account.getAuthenticatedDevice().get().setCapabilities(capabilities);
accounts.update(account);
final long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities));
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@ -104,17 +104,18 @@ public class KeysController {
boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
device.setSignedPreKey(preKeys.getSignedPreKey());
updateAccount = true;
}
if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) {
account.setIdentityKey(preKeys.getIdentityKey());
updateAccount = true;
}
if (updateAccount) {
accounts.update(account);
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> d.setSignedPreKey(preKeys.getSignedPreKey()));
a.setIdentityKey(preKeys.getIdentityKey());
});
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
@ -200,8 +201,7 @@ public class KeysController {
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
device.setSignedPreKey(signedPreKey);
accounts.update(account);
account = accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey));
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);

View File

@ -156,10 +156,11 @@ public class ProfileController {
response = Optional.of(generateAvatarUploadForm(avatar));
}
account.setProfileName(request.getName());
account.setAvatar(avatar);
account.setCurrentProfileVersion(request.getVersion());
accountsManager.update(account);
accountsManager.update(account, a -> {
a.setProfileName(request.getName());
a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion());
});
if (response.isPresent()) return Response.ok(response).build();
else return Response.ok().build();
@ -317,8 +318,7 @@ public class ProfileController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}")
public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
account.setProfileName(name.orElse(null));
accountsManager.update(account);
accountsManager.update(account, a -> a.setProfileName(name.orElse(null)));
}
@Deprecated
@ -382,8 +382,7 @@ public class ProfileController {
.build());
}
account.setAvatar(objectName);
accountsManager.update(account);
accountsManager.update(account, a -> a.setAvatar(objectName));
return profileAvatarUploadAttributes;
}

View File

@ -110,8 +110,8 @@ public class GCMSender {
Device device = account.get().getDevice(message.getDeviceId()).get();
if (device.getUninstalledFeedbackTimestamp() == 0) {
device.setUninstalledFeedbackTimestamp(Util.todayInMillis());
accountsManager.update(account.get());
accountsManager.updateDevice(account.get(), message.getDeviceId(), d ->
d.setUninstalledFeedbackTimestamp(Util.todayInMillis()));
}
}
@ -122,15 +122,11 @@ public class GCMSender {
logger.warn(String.format("Actually received 'CanonicalRegistrationId' ::: (canonical=%s), (original=%s)",
result.getCanonicalRegistrationId(), message.getGcmId()));
Optional<Account> account = getAccountForEvent(message);
if (account.isPresent()) {
//noinspection OptionalGetWithoutIsPresent
Device device = account.get().getDevice(message.getDeviceId()).get();
device.setGcmId(result.getCanonicalRegistrationId());
accountsManager.update(account.get());
}
getAccountForEvent(message).ifPresent(account ->
accountsManager.updateDevice(
account,
message.getDeviceId(),
d -> d.setGcmId(result.getCanonicalRegistrationId())));
canonical.mark();
}

View File

@ -14,11 +14,16 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.security.auth.Subject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
public class Account implements Principal {
@JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class);
@JsonIgnore
private UUID uuid;
@ -58,12 +63,15 @@ public class Account implements Principal {
@JsonProperty("inCds")
private boolean discoverableByPhoneNumber = true;
@JsonProperty("_ddbV")
private int dynamoDbMigrationVersion;
@JsonIgnore
private Device authenticatedDevice;
@JsonProperty
private int version;
@JsonIgnore
private boolean stale;
public Account() {}
@VisibleForTesting
@ -75,47 +83,68 @@ public class Account implements Principal {
}
public Optional<Device> getAuthenticatedDevice() {
requireNotStale();
return Optional.ofNullable(authenticatedDevice);
}
public void setAuthenticatedDevice(Device device) {
requireNotStale();
this.authenticatedDevice = device;
}
public UUID getUuid() {
// this is the one method that may be called on a stale account
return uuid;
}
public void setUuid(UUID uuid) {
requireNotStale();
this.uuid = uuid;
}
public void setNumber(String number) {
requireNotStale();
this.number = number;
}
public String getNumber() {
requireNotStale();
return number;
}
public void addDevice(Device device) {
requireNotStale();
this.devices.remove(device);
this.devices.add(device);
}
public void removeDevice(long deviceId) {
requireNotStale();
this.devices.remove(new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "NA", 0, null));
}
public Set<Device> getDevices() {
requireNotStale();
return devices;
}
public Optional<Device> getMasterDevice() {
requireNotStale();
return getDevice(Device.MASTER_ID);
}
public Optional<Device> getDevice(long deviceId) {
requireNotStale();
for (Device device : devices) {
if (device.getId() == deviceId) {
return Optional.of(device);
@ -126,42 +155,58 @@ public class Account implements Principal {
}
public boolean isGroupsV2Supported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(Device::isGroupsV2Supported);
}
public boolean isStorageSupported() {
requireNotStale();
return devices.stream().anyMatch(device -> device.getCapabilities() != null && device.getCapabilities().isStorage());
}
public boolean isTransferSupported() {
requireNotStale();
return getMasterDevice().map(Device::getCapabilities).map(Device.DeviceCapabilities::isTransfer).orElse(false);
}
public boolean isGv1MigrationSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isGv1Migration());
}
public boolean isSenderKeySupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isSenderKey());
}
public boolean isAnnouncementGroupSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isAnnouncementGroup());
}
public boolean isEnabled() {
requireNotStale();
return getMasterDevice().map(Device::isEnabled).orElse(false);
}
public long getNextDeviceId() {
requireNotStale();
long highestDevice = Device.MASTER_ID;
for (Device device : devices) {
@ -176,6 +221,8 @@ public class Account implements Principal {
}
public int getEnabledDeviceCount() {
requireNotStale();
int count = 0;
for (Device device : devices) {
@ -186,22 +233,32 @@ public class Account implements Principal {
}
public boolean isRateLimited() {
requireNotStale();
return true;
}
public Optional<String> getRelay() {
requireNotStale();
return Optional.empty();
}
public void setIdentityKey(String identityKey) {
requireNotStale();
this.identityKey = identityKey;
}
public String getIdentityKey() {
requireNotStale();
return identityKey;
}
public long getLastSeen() {
requireNotStale();
long lastSeen = 0;
for (Device device : devices) {
@ -214,78 +271,127 @@ public class Account implements Principal {
}
public Optional<String> getCurrentProfileVersion() {
requireNotStale();
return Optional.ofNullable(currentProfileVersion);
}
public void setCurrentProfileVersion(String currentProfileVersion) {
requireNotStale();
this.currentProfileVersion = currentProfileVersion;
}
public String getProfileName() {
requireNotStale();
return name;
}
public void setProfileName(String name) {
requireNotStale();
this.name = name;
}
public String getAvatar() {
requireNotStale();
return avatar;
}
public void setAvatar(String avatar) {
requireNotStale();
this.avatar = avatar;
}
public void setPin(String pin) {
requireNotStale();
this.pin = pin;
}
public void setRegistrationLock(String registrationLock, String registrationLockSalt) {
requireNotStale();
this.registrationLock = registrationLock;
this.registrationLockSalt = registrationLockSalt;
}
public StoredRegistrationLock getRegistrationLock() {
requireNotStale();
return new StoredRegistrationLock(Optional.ofNullable(registrationLock), Optional.ofNullable(registrationLockSalt), Optional.ofNullable(pin), getLastSeen());
}
public Optional<byte[]> getUnidentifiedAccessKey() {
requireNotStale();
return Optional.ofNullable(unidentifiedAccessKey);
}
public void setUnidentifiedAccessKey(byte[] unidentifiedAccessKey) {
requireNotStale();
this.unidentifiedAccessKey = unidentifiedAccessKey;
}
public boolean isUnrestrictedUnidentifiedAccess() {
requireNotStale();
return unrestrictedUnidentifiedAccess;
}
public void setUnrestrictedUnidentifiedAccess(boolean unrestrictedUnidentifiedAccess) {
requireNotStale();
this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess;
}
public boolean isFor(AmbiguousIdentifier identifier) {
requireNotStale();
if (identifier.hasUuid()) return identifier.getUuid().equals(uuid);
else if (identifier.hasNumber()) return identifier.getNumber().equals(number);
else throw new AssertionError();
}
public boolean isDiscoverableByPhoneNumber() {
requireNotStale();
return this.discoverableByPhoneNumber;
}
public void setDiscoverableByPhoneNumber(final boolean discoverableByPhoneNumber) {
requireNotStale();
this.discoverableByPhoneNumber = discoverableByPhoneNumber;
}
public int getDynamoDbMigrationVersion() {
return dynamoDbMigrationVersion;
public int getVersion() {
requireNotStale();
return version;
}
public void setDynamoDbMigrationVersion(int dynamoDbMigrationVersion) {
this.dynamoDbMigrationVersion = dynamoDbMigrationVersion;
public void setVersion(int version) {
requireNotStale();
this.version = version;
}
public void markStale() {
stale = true;
}
private void requireNotStale() {
assert !stale;
//noinspection ConstantConditions
if (stale) {
logger.error("Accessor called on stale account", new RuntimeException());
}
}
// Principal implementation

View File

@ -7,7 +7,7 @@ public interface AccountStore {
boolean create(Account account);
void update(Account account);
void update(Account account) throws ContestedOptimisticLockException;
Optional<Account> get(String number);

View File

@ -13,6 +13,7 @@ import com.codahale.metrics.Timer.Context;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
@ -22,10 +23,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
public class Accounts implements AccountStore {
public static final String ID = "id";
public static final String UID = "uuid";
public static final String ID = "id";
public static final String UID = "uuid";
public static final String NUMBER = "number";
public static final String DATA = "data";
public static final String DATA = "data";
public static final String VERSION = "version";
private static final ObjectMapper mapper = SystemMapper.getMapper();
@ -50,15 +52,19 @@ public class Accounts implements AccountStore {
public boolean create(Account account) {
return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
try (Timer.Context ignored = createTimer.time()) {
UUID uuid = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET data = EXCLUDED.data RETURNING uuid")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapTo(UUID.class)
.findOnly();
final Map<String, Object> resultMap = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET " + DATA + " = EXCLUDED.data, " + VERSION + " = accounts.version + 1 RETURNING uuid, version")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapToMap()
.findOnly();
final UUID uuid = (UUID) resultMap.get(UID);
final int version = (int) resultMap.get(VERSION);
boolean isNew = uuid.equals(account.getUuid());
account.setUuid(uuid);
account.setVersion(version);
return isNew;
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
@ -67,13 +73,23 @@ public class Accounts implements AccountStore {
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = updateTimer.time()) {
handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + UID + " = :uuid")
final int newVersion = account.getVersion() + 1;
int rowsModified = handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json), " + VERSION + " = :newVersion WHERE " + UID + " = :uuid AND " + VERSION + " = :version")
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.bind("version", account.getVersion())
.bind("newVersion", newVersion)
.execute();
if (rowsModified == 0) {
throw new ContestedOptimisticLockException();
}
account.setVersion(newVersion);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}

View File

@ -30,16 +30,20 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse;
public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountStore {
@ -49,8 +53,8 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
static final String ATTR_ACCOUNT_E164 = "P";
// account, serialized to JSON
static final String ATTR_ACCOUNT_DATA = "D";
static final String ATTR_MIGRATION_VERSION = "V";
// internal version for optimistic locking
static final String ATTR_VERSION = "V";
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient;
@ -122,11 +126,19 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
ByteBuffer actualAccountUuid = phoneNumberConstraintCancellationReason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer();
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final int version = get(account.getUuid()).get().getVersion();
account.setVersion(version);
update(account);
return false;
}
if ("TransactionConflict".equals(accountCancellationReason.code())) {
// this should only happen during concurrent update()s for an account migration
throw new ContestedOptimisticLockException();
}
// this shouldnt happen
throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e));
}
@ -146,7 +158,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid),
ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
ATTR_MIGRATION_VERSION, AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
ATTR_VERSION, AttributeValues.fromInt(account.getVersion())))
.build())
.build();
}
@ -172,28 +184,44 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
UPDATE_TIMER.record(() -> {
UpdateItemRequest updateItemRequest;
try {
updateItemRequest = UpdateItemRequest.builder()
.tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression("SET #data = :data, #version = :version")
.conditionExpression("attribute_exists(#number)")
.updateExpression("SET #data = :data ADD #version :version_increment")
.conditionExpression("attribute_exists(#number) AND #version = :version")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164,
"#data", ATTR_ACCOUNT_DATA,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)))
.returnValues(ReturnValue.UPDATED_NEW)
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
client.updateItem(updateItemRequest);
try {
UpdateItemResponse response = client.updateItem(updateItemRequest);
account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1));
} catch (final TransactionConflictException e) {
throw new ContestedOptimisticLockException();
} catch (final ConditionalCheckFailedException e) {
// the exception doesnt give details about which condition failed,
// but we can infer it was an optimistic locking failure if the UUID is known
throw get(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e;
}
});
}
@ -343,9 +371,9 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
.conditionExpression("attribute_not_exists(#uuid) OR (attribute_exists(#uuid) AND #version < :version)")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion()))));
":version", AttributeValues.fromInt(account.getVersion()))));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, accountPut).build();
@ -395,6 +423,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class);
account.setNumber(item.get(ATTR_ACCOUNT_E164).s());
account.setUuid(UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer()));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
return account;

View File

@ -26,6 +26,8 @@ import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import net.logstash.logback.argument.StructuredArguments;
import org.apache.commons.lang3.StringUtils;
@ -40,7 +42,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
public class AccountsManager {
@ -119,7 +120,6 @@ public class AccountsManager {
this.mapper = SystemMapper.getMapper();
this.migrationComparisonMapper = mapper.copy();
migrationComparisonMapper.addMixIn(Account.class, AccountComparisonMixin.class);
migrationComparisonMapper.addMixIn(Device.class, DeviceComparisonMixin.class);
this.dynamicConfigurationManager = dynamicConfigurationManager;
@ -169,25 +169,86 @@ public class AccountsManager {
}
}
public void update(Account account) {
public Account update(Account account, Consumer<Account> updater) {
final Account updatedAccount;
try (Timer.Context ignored = updateTimer.time()) {
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
redisSet(account);
databaseUpdate(account);
updater.accept(account);
{
// optimistically increment version
final int originalVersion = account.getVersion();
account.setVersion(originalVersion + 1);
redisSet(account);
account.setVersion(originalVersion);
}
final UUID uuid = account.getUuid();
updatedAccount = updateWithRetries(account, updater, this::databaseUpdate, () -> databaseGet(uuid).get());
if (dynamoWriteEnabled()) {
runSafelyAndRecordMetrics(() -> {
try {
dynamoUpdate(account);
} catch (final ConditionalCheckFailedException e) {
dynamoCreate(account);
final Optional<Account> dynamoAccount = dynamoGet(uuid);
if (dynamoAccount.isPresent()) {
updater.accept(dynamoAccount.get());
Account dynamoUpdatedAccount = updateWithRetries(dynamoAccount.get(),
updater,
this::dynamoUpdate,
() -> dynamoGet(uuid).get());
return Optional.of(dynamoUpdatedAccount);
}
return true;
}, Optional.of(account.getUuid()), true,
(databaseSuccess, dynamoSuccess) -> Optional.empty(), // both values are always true
return Optional.empty();
}, Optional.of(uuid), Optional.of(updatedAccount),
this::compareAccounts,
"update");
}
// set the cache again, so that all updates are coalesced
redisSet(updatedAccount);
}
return updatedAccount;
}
private Account updateWithRetries(Account account, Consumer<Account> updater, Consumer<Account> persister, Supplier<Account> retriever) {
final int maxTries = 10;
int tries = 0;
while (tries < maxTries) {
try {
persister.accept(account);
final Account updatedAccount;
try {
updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
updatedAccount.setUuid(account.getUuid());
} catch (final IOException e) {
// this should really, truly, never happen
throw new IllegalArgumentException(e);
}
account.markStale();
return updatedAccount;
} catch (final ContestedOptimisticLockException e) {
tries++;
account = retriever.get();
updater.accept(account);
}
}
throw new OptimisticLockRetryLimitExceededException();
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> a.getDevice(deviceId).ifPresent(deviceUpdater));
}
public Optional<Account> get(AmbiguousIdentifier identifier) {
@ -445,6 +506,10 @@ public class AccountsManager {
return Optional.of("number");
}
if (databaseAccount.getVersion() != dynamoAccount.getVersion()) {
return Optional.of("version");
}
if (!Objects.equals(databaseAccount.getIdentityKey(), dynamoAccount.getIdentityKey())) {
return Optional.of("identityKey");
}
@ -566,13 +631,6 @@ public class AccountsManager {
.collect(Collectors.joining(" -> "));
}
private static abstract class AccountComparisonMixin extends Account {
@JsonIgnore
private int dynamoDbMigrationVersion;
}
private static abstract class DeviceComparisonMixin extends Device {
@JsonIgnore

View File

@ -0,0 +1,13 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class ContestedOptimisticLockException extends RuntimeException {
public ContestedOptimisticLockException() {
super(null, null, true, false);
}
}

View File

@ -0,0 +1,10 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class OptimisticLockRetryLimitExceededException extends RuntimeException {
}

View File

@ -5,20 +5,20 @@
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@ -47,36 +47,42 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
for (Account account : chunkAccounts) {
boolean update = false;
for (Device device : account.getDevices()) {
if (device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis())
{
if (device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis()) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
final Set<Device> devices = account.getDevices();
for (Device device : devices) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
expired.mark();
} else {
device.setUninstalledFeedbackTimestamp(0);
recovered.mark();
}
update = true;
}
}
if (update) {
accountsManager.update(account);
account = accountsManager.update(account, a -> {
for (Device device: a.getDevices()) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
} else {
device.setUninstalledFeedbackTimestamp(0);
}
}
}
});
directoryUpdateAccounts.add(account);
}
}
@ -85,4 +91,13 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts);
}
}
private boolean deviceNeedsUpdate(final Device device) {
return device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
private boolean deviceExpired(final Device device) {
return device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
}

View File

@ -27,6 +27,7 @@ public class AccountRowMapper implements RowMapper<Account> {
Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class);
account.setNumber(resultSet.getString(Accounts.NUMBER));
account.setUuid(UUID.fromString(resultSet.getString(Accounts.UID)));
account.setVersion(resultSet.getInt(Accounts.VERSION));
return account;
} catch (IOException e) {
throw new SQLException(e);

View File

@ -375,4 +375,10 @@
</addColumn>
</changeSet>
<changeSet id="25" author="chris">
<addColumn tableName="accounts">
<column name="version" type="int" defaultValue="0"/>
</addColumn>
</changeSet>
</databaseChangeLog>

View File

@ -50,6 +50,7 @@ import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
class AccountsDynamoDbTest {
@ -211,6 +212,10 @@ class AccountsDynamoDbTest {
verifyStoredState("+14151112222", account.getUuid(), account);
account.setProfileName("name");
accountsDynamoDb.update(account);
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1);
@ -252,13 +257,44 @@ class AccountsDynamoDbTest {
assertThatThrownBy(() -> accountsDynamoDb.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
account.setDynamoDbMigrationVersion(5);
account.setProfileName("name");
accountsDynamoDb.update(account);
assertThat(account.getVersion()).isEqualTo(2);
verifyStoredState("+14151112222", account.getUuid(), account);
account.setVersion(1);
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
account.setVersion(2);
account.setProfileName("name2");
accountsDynamoDb.update(account);
verifyStoredState("+14151112222", account.getUuid(), account);
}
@Test
void testUpdateWithMockTransactionConflictException() {
final DynamoDbClient dynamoDbClient = mock(DynamoDbClient.class);
accountsDynamoDb = new AccountsDynamoDb(dynamoDbClient, mock(DynamoDbAsyncClient.class),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(), NUMBERS_TABLE_NAME, mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
when(dynamoDbClient.updateItem(any(UpdateItemRequest.class)))
.thenThrow(TransactionConflictException.class);
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
}
@Test
void testRetrieveFrom() {
List<Account> users = new ArrayList<>();
@ -463,7 +499,7 @@ class AccountsDynamoDbTest {
assertThat(migrated).isFalse();
verifyStoredState("+14151112222", firstUuid, account);
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
account.setVersion(account.getVersion() + 1);
migrated = accountsDynamoDb.migrate(account).get();
@ -504,8 +540,8 @@ class AccountsDynamoDbTest {
String data = new String(get.item().get(AccountsDynamoDb.ATTR_ACCOUNT_DATA).b().asByteArray(), StandardCharsets.UTF_8);
assertThat(data).isNotEmpty();
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_MIGRATION_VERSION, -1))
.isEqualTo(expecting.getDynamoDbMigrationVersion());
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_VERSION, -1))
.isEqualTo(expecting.getVersion());
Account result = AccountsDynamoDb.fromItem(get.item());
verifyStoredState(number, uuid, result, expecting);
@ -518,6 +554,7 @@ class AccountsDynamoDbTest {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(result.getVersion()).isEqualTo(expecting.getVersion());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {

View File

@ -0,0 +1,274 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit5.EmbeddedPostgresExtension;
import com.opentable.db.postgres.junit5.PreparedDbExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
class AccountsManagerConcurrentModificationIntegrationTest {
@RegisterExtension
static PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private static final String ACCOUNTS_TABLE_NAME = "accounts_test";
private static final String NUMBERS_TABLE_NAME = "numbers_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.hashKey(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build())
.build();
private Accounts accounts;
private AccountsDynamoDb accountsDynamoDb;
private AccountsManager accountsManager;
private RedisAdvancedClusterCommands<String, String> commands;
private Executor mutationExecutor = new ThreadPoolExecutor(20, 20, 5, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20));
@BeforeEach
void setup() {
{
CreateTableRequest createNumbersTableRequest = CreateTableRequest.builder()
.tableName(NUMBERS_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
dynamoDbExtension.getDynamoDbClient().createTable(createNumbersTableRequest);
}
accountsDynamoDb = new AccountsDynamoDb(
dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(),
NUMBERS_TABLE_NAME,
mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
{
final CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration();
circuitBreakerConfiguration.setIgnoredExceptions(List.of("org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException"));
FaultTolerantDatabase faultTolerantDatabase = new FaultTolerantDatabase("accountsTest",
Jdbi.create(db.getTestDatabase()),
circuitBreakerConfiguration);
accounts = new Accounts(faultTolerantDatabase);
}
{
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
final DynamicAccountsDynamoDbMigrationConfiguration config = dynamicConfiguration
.getAccountsDynamoDbMigrationConfiguration();
config.setDeleteEnabled(true);
config.setReadEnabled(true);
config.setWriteEnabled(true);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), anyString())).thenReturn(true);
commands = mock(RedisAdvancedClusterCommands.class);
accountsManager = new AccountsManager(
accounts,
accountsDynamoDb,
RedisClusterHelper.buildMockRedisCluster(commands),
mock(DeletedAccounts.class),
mock(DirectoryQueue.class),
mock(KeysDynamoDb.class),
mock(MessagesManager.class),
mock(UsernamesManager.class),
mock(ProfilesManager.class),
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
experimentEnrollmentManager,
dynamicConfigurationManager);
}
}
@Test
void testConcurrentUpdate() throws IOException {
final UUID uuid = UUID.randomUUID();
accountsManager.create(generateAccount("+14155551212", uuid));
final String profileName = "name";
final String avatar = "avatar";
final boolean discoverableByPhoneNumber = false;
final String currentProfileVersion = "cpv";
final String identityKey = "ikey";
final byte[] unidentifiedAccessKey = new byte[]{1};
final String pin = "1234";
final String registrationLock = "reglock";
final AuthenticationCredentials credentials = new AuthenticationCredentials(registrationLock);
final boolean unrestrictedUnidentifiedAccess = true;
final long lastSeen = Instant.now().getEpochSecond();
CompletableFuture.allOf(
modifyAccount(uuid, account -> account.setProfileName(profileName)),
modifyAccount(uuid, account -> account.setAvatar(avatar)),
modifyAccount(uuid, account -> account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber)),
modifyAccount(uuid, account -> account.setCurrentProfileVersion(currentProfileVersion)),
modifyAccount(uuid, account -> account.setIdentityKey(identityKey)),
modifyAccount(uuid, account -> account.setUnidentifiedAccessKey(unidentifiedAccessKey)),
modifyAccount(uuid, account -> account.setPin(pin)),
modifyAccount(uuid, account -> account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt())),
modifyAccount(uuid, account -> account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setLastSeen(lastSeen)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setName("deviceName"))
).join();
final Account managerAccount = accountsManager.get(uuid).get();
final Account dbAccount = accounts.get(uuid).get();
final Account dynamoAccount = accountsDynamoDb.get(uuid).get();
final Account redisAccount = getLastAccountFromRedisMock(commands);
Stream.of(
new Pair<>("manager", managerAccount),
new Pair<>("db", dbAccount),
new Pair<>("dynamo", dynamoAccount),
new Pair<>("redis", redisAccount)
).forEach(pair ->
verifyAccount(pair.first(), pair.second(), profileName, avatar, discoverableByPhoneNumber,
currentProfileVersion, identityKey, unidentifiedAccessKey, pin, registrationLock,
unrestrictedUnidentifiedAccess, lastSeen)
);
}
private Account getLastAccountFromRedisMock(RedisAdvancedClusterCommands<String, String> commands) throws IOException {
ArgumentCaptor<String> redisSetArgumentCapture = ArgumentCaptor.forClass(String.class);
verify(commands, atLeast(20)).set(anyString(), redisSetArgumentCapture.capture());
return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class);
}
private void verifyAccount(final String name, final Account account, final String profileName, final String avatar, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final String identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAcces, final long lastSeen) {
assertAll(name,
() -> assertEquals(profileName, account.getProfileName()),
() -> assertEquals(avatar, account.getAvatar()),
() -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()),
() -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().get()),
() -> assertEquals(identityKey, account.getIdentityKey()),
() -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().get()),
() -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock, pin)),
() -> assertEquals(unrestrictedUnidentifiedAcces, account.isUnrestrictedUnidentifiedAccess())
);
}
private CompletableFuture<?> modifyAccount(final UUID uuid, final Consumer<Account> accountMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.update(account, accountMutation);
}, mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final long deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.updateDevice(account, deviceId, deviceMutation);
}, mutationExecutor);
}
private Account generateAccount(String number, UUID uuid) {
Device device = generateDevice(1);
return generateAccount(number, uuid, Collections.singleton(device));
}
private Account generateAccount(String number, UUID uuid, Set<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
return new Account(number, uuid, devices, unidentifiedAccessKey);
}
private Device generateDevice(long id) {
Random random = new Random(System.currentTimeMillis());
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt());
return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt() , 0, new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean()));
}
}

View File

@ -5,104 +5,120 @@
package org.whispersystems.textsecuregcm.tests.auth;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BaseAccountAuthenticatorTest {
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
class BaseAccountAuthenticatorTest {
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
@Before
public void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
}
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
@Test
public void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
AccountsHelper.setupMockUpdate(accountsManager);
}
verify(accountsManager, never()).update(acct1);
verify(accountsManager).update(acct2);
@Test
void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
@Test
public void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
verify(accountsManager, never()).update(acct1);
verify(accountsManager, never()).update(acct2);
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(today);
}
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
}
@Test
void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
@Test
public void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager).update(acct1);
verify(accountsManager).update(acct2);
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager, never()).updateDevice(eq(acct2), anyLong(), any());
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(yesterday);
}
@Test
public void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
@Test
void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
baseAccountAuthenticator.updateLastSeen(oldAccount, oldAccount.getDevices().stream().findFirst().get());
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
verify(accountsManager).update(oldAccount);
baseAccountAuthenticator.updateLastSeen(acct1, device1);
baseAccountAuthenticator.updateLastSeen(acct2, device2);
assertThat(oldAccount.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
verify(accountsManager).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
assertThat(device1.getLastSeen()).isEqualTo(today);
assertThat(device2.getLastSeen()).isEqualTo(today);
}
@Test
void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
final Device device = oldAccount.getDevices().stream().findFirst().get();
baseAccountAuthenticator.updateLastSeen(oldAccount, device);
verify(accountsManager).updateDevice(eq(oldAccount), anyLong(), any());
assertThat(device.getLastSeen()).isEqualTo(today);
}
}

View File

@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -171,6 +173,8 @@ class AccountControllerTest {
new SecureRandom().nextBytes(registration_lock_key);
AuthenticationCredentials registrationLockCredentials = new AuthenticationCredentials(Hex.toStringCondensed(registration_lock_key));
AccountsHelper.setupMockUpdate(accountsManager);
when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationDailyLimiter()).thenReturn(rateLimiter);
@ -1352,7 +1356,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("c00lz0rz"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@ -1368,7 +1372,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@ -1385,7 +1389,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@ -1402,7 +1406,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null);
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@ -1419,7 +1423,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("third"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("fourth"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
}
@ -1544,7 +1548,7 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.UNDISCOVERABLE_ACCOUNT);
verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.UNDISCOVERABLE_ACCOUNT));
}
@Test
@ -1557,7 +1561,7 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, false, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.VALID_ACCOUNT);
verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.VALID_ACCOUNT));
}
@Test

View File

@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
@ -47,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -124,6 +126,8 @@ public class DeviceControllerTest {
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty());
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
AccountsHelper.setupMockUpdate(accountsManager);
}
@Test
@ -360,7 +364,7 @@ public class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(AuthHelper.VALID_ACCOUNT);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
@ -15,6 +16,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@ -61,6 +63,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -114,13 +117,15 @@ class KeysControllerTest {
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
Set<Device> allDevices = new HashSet<Device>() {{
Set<Device> allDevices = new HashSet<>() {{
add(sampleDevice);
add(sampleDevice2);
add(sampleDevice3);
add(sampleDevice4);
}};
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
@ -142,7 +147,7 @@ class KeysControllerTest {
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.<Device>empty());
when(existsAccount.getDevice(22L)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey");
@ -256,7 +261,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@Test
@ -271,7 +276,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@ -578,7 +583,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@ -587,7 +592,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.VALID_ACCOUNT);
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
@ -612,7 +617,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@ -621,7 +626,7 @@ class KeysControllerTest {
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any());
}
@Test

View File

@ -20,7 +20,8 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
@ -29,9 +30,10 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.signal.zkgroup.InvalidInputException;
@ -57,13 +59,15 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
public class ProfileControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class ProfileControllerTest {
private static AccountsManager accountsManager = mock(AccountsManager.class );
private static ProfilesManager profilesManager = mock(ProfilesManager.class);
@ -82,30 +86,30 @@ public class ProfileControllerTest {
private Account profileAccount;
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
@ClassRule
public static final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
@Before
public void setup() throws Exception {
@BeforeEach
void setup() throws Exception {
reset(s3client);
AccountsHelper.setupMockUpdate(accountsManager);
dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
@ -161,8 +165,13 @@ public class ProfileControllerTest {
clearInvocations(profilesManager);
}
@AfterEach
void teardown() {
reset(accountsManager);
}
@Test
public void testProfileGetByUuid() throws RateLimitExceededException {
void testProfileGetByUuid() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
@ -180,7 +189,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByNumber() throws RateLimitExceededException {
void testProfileGetByNumber() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@ -201,7 +210,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByUsername() throws RateLimitExceededException {
void testProfileGetByUsername() throws RateLimitExceededException {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
@ -220,7 +229,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetUnauthorized() {
void testProfileGetUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@ -230,7 +239,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileGetByUsernameUnauthorized() {
void testProfileGetByUsernameUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
@ -241,7 +250,7 @@ public class ProfileControllerTest {
@Test
public void testProfileGetByUsernameNotFound() throws RateLimitExceededException {
void testProfileGetByUsernameNotFound() throws RateLimitExceededException {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkillerzzzzz")
.request()
@ -256,7 +265,7 @@ public class ProfileControllerTest {
@Test
public void testProfileGetDisabled() {
void testProfileGetDisabled() {
Response response = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO)
.request()
@ -267,7 +276,7 @@ public class ProfileControllerTest {
}
@Test
public void testProfileCapabilities() {
void testProfileCapabilities() {
Profile profile= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_NUMBER)
.request()
@ -293,7 +302,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileNameDeprecated() {
void testSetProfileNameDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012")
.request()
@ -302,11 +311,11 @@ public class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager, times(1)).update(any(Account.class));
verify(accountsManager, times(1)).update(any(Account.class), any());
}
@Test
public void testSetProfileNameExtendedDeprecated() {
void testSetProfileNameExtendedDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678")
.request()
@ -315,11 +324,11 @@ public class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager, times(1)).update(any(Account.class));
verify(accountsManager, times(1)).update(any(Account.class), any());
}
@Test
public void testSetProfileNameWrongSizeDeprecated() {
void testSetProfileNameWrongSizeDeprecated() {
Response response = resources.getJerseyTest()
.target("/v1/profile/name/1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890")
.request()
@ -333,7 +342,7 @@ public class ProfileControllerTest {
/////
@Test
public void testSetProfileWantAvatarUpload() throws InvalidInputException {
void testSetProfileWantAvatarUpload() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
ProfileAvatarUploadAttributes uploadAttributes = resources.getJerseyTest()
@ -358,7 +367,7 @@ public class ProfileControllerTest {
assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); }
@Test
public void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException {
void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
Response response = resources.getJerseyTest()
@ -372,7 +381,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileWithoutAvatarUpload() throws InvalidInputException {
void testSetProfileWithoutAvatarUpload() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@ -406,7 +415,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException {
void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO);
ProfileAvatarUploadAttributes uploadAttributes= resources.getJerseyTest()
@ -430,7 +439,7 @@ public class ProfileControllerTest {
assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); }
@Test
public void testSetProfileExtendedName() throws InvalidInputException {
void testSetProfileExtendedName() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO);
final String name = RandomStringUtils.randomAlphabetic(380);
@ -456,7 +465,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileEmojiAndBioText() throws InvalidInputException {
void testSetProfileEmojiAndBioText() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@ -495,7 +504,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfilePaymentAddress() throws InvalidInputException {
void testSetProfilePaymentAddress() throws InvalidInputException {
when(dynamicPaymentsConfiguration.getAllowedCountryCodes())
.thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO)));
@ -536,7 +545,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException {
void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException {
ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID);
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
@ -557,7 +566,7 @@ public class ProfileControllerTest {
}
@Test
public void testGetProfileByVersion() throws RateLimitExceededException {
void testGetProfileByVersion() throws RateLimitExceededException {
Profile profile = resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO + "/validversion")
.request()
@ -582,7 +591,7 @@ public class ProfileControllerTest {
}
@Test
public void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException {
void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException {
when(dynamicPaymentsConfiguration.getAllowedCountryCodes())
.thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO)));
@ -606,7 +615,7 @@ public class ProfileControllerTest {
}
@Test
public void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() {
void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() {
when(profilesManager.get(AuthHelper.VALID_UUID_TWO, "validversion")).thenReturn(
Optional.of(new VersionedProfile(null, null, null, null, null, "paymentaddress", null)));
Profile profile = resources.getJerseyTest()

View File

@ -23,6 +23,7 @@ import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService;
import org.whispersystems.textsecuregcm.util.Util;
@ -40,6 +41,8 @@ public class GCMSenderTest {
when(successResult.hasCanonicalRegistrationId()).thenReturn(false);
when(successResult.isSuccess()).thenReturn(true);
AccountsHelper.setupMockUpdate(accountsManager);
GcmMessage message = new GcmMessage("foo", "+12223334444", 1, GcmMessage.Type.NOTIFICATION, Optional.empty());
GCMSender gcmSender = new GCMSender(executorService, accountsManager, sender);
@ -65,6 +68,8 @@ public class GCMSenderTest {
Account destinationAccount = mock(Account.class);
Device destinationDevice = mock(Device.class );
AccountsHelper.setupMockUpdate(accountsManager);
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
@ -85,7 +90,7 @@ public class GCMSenderTest {
verify(sender, times(1)).send(any(Message.class));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any());
verify(destinationDevice, times(1)).setUninstalledFeedbackTimestamp(eq(Util.todayInMillis()));
}
@ -107,6 +112,8 @@ public class GCMSenderTest {
when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount));
when(destinationDevice.getGcmId()).thenReturn(gcmId);
AccountsHelper.setupMockUpdate(accountsManager);
when(canonicalResult.isInvalidRegistrationId()).thenReturn(false);
when(canonicalResult.isUnregistered()).thenReturn(false);
when(canonicalResult.hasCanonicalRegistrationId()).thenReturn(true);
@ -124,7 +131,7 @@ public class GCMSenderTest {
verify(sender, times(1)).send(any(Message.class));
verify(accountsManager, times(1)).get(eq(destinationNumber));
verify(accountsManager, times(1)).update(eq(destinationAccount));
verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any());
verify(destinationDevice, times(1)).setGcmId(eq(canonicalId));
}

View File

@ -6,8 +6,10 @@
package org.whispersystems.textsecuregcm.tests.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -17,26 +19,26 @@ import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
public class AccountTest {
class AccountTest {
private final Device oldMasterDevice = mock(Device.class);
private final Device recentMasterDevice = mock(Device.class);
private final Device agingSecondaryDevice = mock(Device.class);
private final Device oldMasterDevice = mock(Device.class);
private final Device recentMasterDevice = mock(Device.class);
private final Device agingSecondaryDevice = mock(Device.class);
private final Device recentSecondaryDevice = mock(Device.class);
private final Device oldSecondaryDevice = mock(Device.class);
private final Device oldSecondaryDevice = mock(Device.class);
private final Device gv2CapableDevice = mock(Device.class);
private final Device gv2IncapableDevice = mock(Device.class);
private final Device gv2CapableDevice = mock(Device.class);
private final Device gv2IncapableDevice = mock(Device.class);
private final Device gv2IncapableExpiredDevice = mock(Device.class);
private final Device gv1MigrationCapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableDevice = mock(Device.class);
private final Device gv1MigrationIncapableExpiredDevice = mock(Device.class);
private final Device senderKeyCapableDevice = mock(Device.class);
@ -47,8 +49,8 @@ public class AccountTest {
private final Device announcementGroupIncapableDevice = mock(Device.class);
private final Device announcementGroupIncapableExpiredDevice = mock(Device.class);
@Before
public void setup() {
@BeforeEach
void setup() {
when(oldMasterDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldMasterDevice.isEnabled()).thenReturn(true);
when(oldMasterDevice.getId()).thenReturn(Device.MASTER_ID);
@ -119,9 +121,9 @@ public class AccountTest {
}
@Test
public void testIsEnabled() {
final Device enabledMasterDevice = mock(Device.class);
final Device enabledLinkedDevice = mock(Device.class);
void testIsEnabled() {
final Device enabledMasterDevice = mock(Device.class);
final Device enabledLinkedDevice = mock(Device.class);
final Device disabledMasterDevice = mock(Device.class);
final Device disabledLinkedDevice = mock(Device.class);
@ -144,7 +146,7 @@ public class AccountTest {
}
@Test
public void testCapabilities() {
void testCapabilities() {
Account uuidCapable = new Account("+14152222222", UUID.randomUUID(), new HashSet<Device>() {{
add(gv2CapableDevice);
}}, "1234".getBytes());
@ -165,13 +167,13 @@ public class AccountTest {
}
@Test
public void testIsTransferSupported() {
final Device transferCapableMasterDevice = mock(Device.class);
final Device nonTransferCapableMasterDevice = mock(Device.class);
final Device transferCapableLinkedDevice = mock(Device.class);
void testIsTransferSupported() {
final Device transferCapableMasterDevice = mock(Device.class);
final Device nonTransferCapableMasterDevice = mock(Device.class);
final Device transferCapableLinkedDevice = mock(Device.class);
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
when(transferCapableMasterDevice.getId()).thenReturn(1L);
when(transferCapableMasterDevice.isMaster()).thenReturn(true);
@ -213,10 +215,12 @@ public class AccountTest {
}
@Test
public void testDiscoverableByPhoneNumber() {
final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice), "1234".getBytes());
void testDiscoverableByPhoneNumber() {
final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice),
"1234".getBytes());
assertTrue("Freshly-loaded legacy accounts should be discoverable by phone number.", account.isDiscoverableByPhoneNumber());
assertTrue(account.isDiscoverableByPhoneNumber(),
"Freshly-loaded legacy accounts should be discoverable by phone number.");
account.setDiscoverableByPhoneNumber(false);
assertFalse(account.isDiscoverableByPhoneNumber());
@ -226,21 +230,29 @@ public class AccountTest {
}
@Test
public void isGroupsV2Supported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
void isGroupsV2Supported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
}
@Test
public void isGv1MigrationSupported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
void isGv1MigrationSupported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertFalse(
new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertTrue(new Account("+18005551234", UUID.randomUUID(),
Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8))
.isGv1MigrationSupported());
}
@Test
public void isSenderKeySupported() {
void isSenderKeySupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice, senderKeyIncapableDevice),
@ -251,7 +263,7 @@ public class AccountTest {
}
@Test
public void isAnnouncementGroupSupported() {
void isAnnouncementGroupSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(),
Set.of(announcementGroupCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
@ -262,4 +274,16 @@ public class AccountTest {
Set.of(announcementGroupCapableDevice, announcementGroupIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
}
@Test
void stale() {
final Account account = new Account("+14151234567", UUID.randomUUID(), Collections.emptySet(), new byte[0]);
assertDoesNotThrow(account::getNumber);
account.markStale();
assertThrows(AssertionError.class, account::getNumber);
assertDoesNotThrow(account::getUuid);
}
}

View File

@ -6,11 +6,14 @@
package org.whispersystems.textsecuregcm.tests.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -22,13 +25,17 @@ import static org.mockito.Mockito.when;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException;
import java.util.HashSet;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Consumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
@ -41,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsDynamoDb;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@ -48,14 +56,22 @@ import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
class AccountsManagerTest {
private DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> {
// it is implicit in the update() contract is that a successful call will
// result in an incremented version
final Account updatedAccount = answer.getArgument(0, Account.class);
updatedAccount.setVersion(updatedAccount.getVersion() + 1);
return null;
};
@BeforeEach
void setup() {
@ -326,7 +342,7 @@ class AccountsManagerTest {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUpdate_dynamoDbMigration(boolean dynamoEnabled) {
void testUpdate_dynamoDbMigration(boolean dynamoEnabled) throws IOException {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
@ -335,35 +351,59 @@ class AccountsManagerTest {
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(dynamoEnabled);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
// database fetches should always return new instances
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any(Account.class));
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts,
directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(0, account.getDynamoDbMigrationVersion());
Account updatedAccount = accountsManager.update(account, a -> a.setProfileName("name"));
accountsManager.update(account);
assertThrows(AssertionError.class, account::getProfileName, "Account passed to update() should be stale");
assertEquals(1, account.getDynamoDbMigrationVersion());
assertNotSame(updatedAccount, account);
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).update(account);
if (dynamoEnabled) {
ArgumentCaptor<Account> argumentCaptor = ArgumentCaptor.forClass(Account.class);
verify(accountsDynamoDb, times(1)).update(argumentCaptor.capture());
assertEquals(uuid, argumentCaptor.getValue().getUuid());
} else {
verify(accountsDynamoDb, never()).update(any());
}
verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
ArgumentCaptor<String> redisSetArgumentCapture = ArgumentCaptor.forClass(String.class);
verify(commands, times(4)).set(anyString(), redisSetArgumentCapture.capture());
Account firstAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(1), Account.class);
Account secondAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(3), Account.class);
// uuid is @JsonIgnore, so we need to set it for compareAccounts to work
firstAccountCached.setUuid(uuid);
secondAccountCached.setUuid(uuid);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(firstAccountCached), Optional.of(secondAccountCached)));
}
@Test
void testUpdate_dynamoConditionFailed() {
void testUpdate_dynamoMissing() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
@ -382,25 +422,158 @@ class AccountsManagerTest {
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
doThrow(ConditionalCheckFailedException.class).when(accountsDynamoDb).update(any(Account.class));
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty());
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any());
doAnswer(ACCOUNT_UPDATE_ANSWER).when(accountsDynamoDb).update(any());
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts,
directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(0, account.getDynamoDbMigrationVersion());
accountsManager.update(account);
assertEquals(1, account.getDynamoDbMigrationVersion());
Account updatedAccount = accountsManager.update(account, a -> {});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, times(1)).update(account);
verify(accountsDynamoDb, times(1)).create(account);
verify(accountsDynamoDb, never()).update(account);
verify(accountsDynamoDb, times(1)).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
assertEquals(1, updatedAccount.getVersion());
}
@Test
void testUpdate_optimisticLockingFailure() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any());
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accountsDynamoDb).update(any());
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
account = accountsManager.update(account, a -> a.setProfileName("name"));
assertEquals(1, account.getVersion());
assertEquals("name", account.getProfileName());
verify(accounts, times(1)).get(uuid);
verify(accounts, times(2)).update(any());
verifyNoMoreInteractions(accounts);
// dynamo has an extra get() because the account is fetched before every update
verify(accountsDynamoDb, times(2)).get(uuid);
verify(accountsDynamoDb, times(2)).update(any());
verifyNoMoreInteractions(accountsDynamoDb);
}
@Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
enableDynamo(true);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accountsDynamoDb.create(any())).thenThrow(ContestedOptimisticLockException.class);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
accountsManager.update(account, a -> {});
verify(accounts, times(1)).update(account);
verifyNoMoreInteractions(accounts);
verify(accountsDynamoDb, times(1)).get(uuid);
verifyNoMoreInteractions(accountsDynamoDb);
}
@Test
void testUpdateDevice() throws Exception {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
Accounts accounts = mock(Accounts.class);
AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class);
DeletedAccounts deletedAccounts = mock(DeletedAccounts.class);
DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class);
MessagesManager messagesManager = mock(MessagesManager.class);
UsernamesManager usernamesManager = mock(UsernamesManager.class);
ProfilesManager profilesManager = mock(ProfilesManager.class);
SecureBackupClient secureBackupClient = mock(SecureBackupClient.class);
SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.empty(), Optional.empty()));
final UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]);
when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16])));
assertTrue(account.getDevices().isEmpty());
Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(new SignedPreKey(1L, "key", "signature"));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@SuppressWarnings("unchecked") Consumer<Device> deviceUpdater = mock(Consumer.class);
@SuppressWarnings("unchecked") Consumer<Device> unknownDeviceUpdater = mock(Consumer.class);
account = accountsManager.updateDevice(account, deviceId, deviceUpdater);
account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName"));
assertEquals("deviceName", account.getDevice(deviceId).get().getName());
verify(deviceUpdater, times(1)).accept(any(Device.class));
accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater);
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testCompareAccounts() throws Exception {
RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
@ -479,9 +652,11 @@ class AccountsManagerTest {
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
a1.setDynamoDbMigrationVersion(1);
a1.setVersion(1);
assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
assertEquals(Optional.of("version"), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2)));
a2.setVersion(1);
a2.setProfileName("name");

View File

@ -167,6 +167,12 @@ public class AccountsTest {
accounts.update(account);
account.setProfileName("profileName");
accounts.update(account);
assertThat(account.getVersion()).isEqualTo(2);
Optional<Account> retrieved = accounts.get("+14151112222");
assertThat(retrieved.isPresent()).isTrue();
@ -359,6 +365,7 @@ public class AccountsTest {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(result.getVersion()).isEqualTo(expecting.getVersion());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {

View File

@ -5,16 +5,16 @@
package org.whispersystems.textsecuregcm.tests.storage;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.util.Util;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.List;
@ -22,11 +22,20 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.Util;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
public class PushFeedbackProcessorTest {
class PushFeedbackProcessorTest {
private AccountsManager accountsManager = mock(AccountsManager.class);
private DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
@ -46,8 +55,10 @@ public class PushFeedbackProcessorTest {
private Device stillActiveDevice = mock(Device.class);
private Device undiscoverableDevice = mock(Device.class);
@Before
public void setup() {
@BeforeEach
void setup() {
AccountsHelper.setupMockUpdate(accountsManager);
when(uninstalledDevice.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2));
when(uninstalledDevice.getLastSeen()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2));
when(uninstalledDeviceTwo.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(3));
@ -85,7 +96,7 @@ public class PushFeedbackProcessorTest {
@Test
public void testEmpty() throws AccountDatabaseCrawlerRestartException {
void testEmpty() throws AccountDatabaseCrawlerRestartException {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), Collections.emptyList());
@ -94,7 +105,7 @@ public class PushFeedbackProcessorTest {
}
@Test
public void testUpdate() throws AccountDatabaseCrawlerRestartException {
void testUpdate() throws AccountDatabaseCrawlerRestartException {
PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue);
processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), List.of(uninstalledAccount, mixedAccount, stillActiveAccount, freshAccount, cleanAccount, undiscoverableAccount));
@ -102,7 +113,7 @@ public class PushFeedbackProcessorTest {
verify(uninstalledDevice).setGcmId(isNull());
verify(uninstalledDevice).setFetchesMessages(eq(false));
verify(accountsManager).update(eq(uninstalledAccount));
verify(accountsManager).update(eq(uninstalledAccount), any());
verify(uninstalledDeviceTwo).setApnId(isNull());
verify(uninstalledDeviceTwo).setGcmId(isNull());
@ -112,33 +123,35 @@ public class PushFeedbackProcessorTest {
verify(installedDevice, never()).setGcmId(any());
verify(installedDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager).update(eq(mixedAccount));
verify(accountsManager).update(eq(mixedAccount), any());
verify(recentUninstalledDevice, never()).setApnId(any());
verify(recentUninstalledDevice, never()).setGcmId(any());
verify(recentUninstalledDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager, never()).update(eq(freshAccount));
verify(accountsManager, never()).update(eq(freshAccount), any());
verify(installedDeviceTwo, never()).setApnId(any());
verify(installedDeviceTwo, never()).setGcmId(any());
verify(installedDeviceTwo, never()).setFetchesMessages(anyBoolean());
verify(accountsManager, never()).update(eq(cleanAccount));
verify(accountsManager, never()).update(eq(cleanAccount), any());
verify(stillActiveDevice).setUninstalledFeedbackTimestamp(eq(0L));
verify(stillActiveDevice, never()).setApnId(any());
verify(stillActiveDevice, never()).setGcmId(any());
verify(stillActiveDevice, never()).setFetchesMessages(anyBoolean());
verify(accountsManager).update(eq(stillActiveAccount));
verify(accountsManager).update(eq(stillActiveAccount), any());
final ArgumentCaptor<List<Account>> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class);
verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture());
assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount)));
final List<UUID> refreshedUuids = refreshedAccountArgumentCaptor.getValue().stream()
.map(Account::getUuid)
.collect(Collectors.toList());
assertTrue(refreshedUuids.containsAll(List.of(undiscoverableAccount.getUuid(), uninstalledAccount.getUuid())));
}
}

View File

@ -0,0 +1,148 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockingDetails;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.function.Consumer;
import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper {
public static void setupMockUpdate(final AccountsManager mockAccountsManager) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
answer.getArgument(1, Consumer.class).accept(account);
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDevice(any(), anyLong(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final Long deviceId = answer.getArgument(1, Long.class);
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
return copyAndMarkStale(account);
});
}
private static Account copyAndMarkStale(Account account) throws IOException {
MockingDetails mockingDetails = mockingDetails(account);
final Account updatedAccount;
if (mockingDetails.isMock()) {
updatedAccount = mock(Account.class);
// its not possible to make `account` behave as if it were stale, because we use static mocks in AuthHelper
for (Stubbing stubbing : mockingDetails.getStubbings()) {
switch (stubbing.getInvocation().getMethod().getName()) {
case "getUuid": {
when(updatedAccount.getUuid()).thenAnswer(stubbing);
break;
}
case "getNumber": {
when(updatedAccount.getNumber()).thenAnswer(stubbing);
break;
}
case "getDevices": {
when(updatedAccount.getDevices())
.thenAnswer(stubbing);
break;
}
case "getDevice": {
when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0)))
.thenAnswer(stubbing);
break;
}
case "getMasterDevice": {
when(updatedAccount.getMasterDevice()).thenAnswer(stubbing);
break;
}
case "getAuthenticatedDevice": {
when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing);
break;
}
case "isEnabled": {
when(updatedAccount.isEnabled()).thenAnswer(stubbing);
break;
}
case "isDiscoverableByPhoneNumber": {
when(updatedAccount.isDiscoverableByPhoneNumber()).thenAnswer(stubbing);
break;
}
case "getNextDeviceId": {
when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
break;
}
case "isGroupsV2Supported": {
when(updatedAccount.isGroupsV2Supported()).thenAnswer(stubbing);
break;
}
case "isGv1MigrationSupported": {
when(updatedAccount.isGv1MigrationSupported()).thenAnswer(stubbing);
break;
}
case "isSenderKeySupported": {
when(updatedAccount.isSenderKeySupported()).thenAnswer(stubbing);
break;
}
case "isAnnouncementGroupSupported": {
when(updatedAccount.isAnnouncementGroupSupported()).thenAnswer(stubbing);
break;
}
case "getEnabledDeviceCount": {
when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing);
break;
}
case "getRelay": {
// TODO unused
when(updatedAccount.getRelay()).thenAnswer(stubbing);
break;
}
case "getRegistrationLock": {
when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
break;
}
case "getIdentityKey": {
when(updatedAccount.getIdentityKey()).thenAnswer(stubbing);
break;
}
default: {
throw new IllegalArgumentException(
"unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
}
}
}
} else {
final ObjectMapper mapper = SystemMapper.getMapper();
updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
updatedAccount.setNumber(account.getNumber());
account.markStale();
}
return updatedAccount;
}
public static Account eqUuid(Account value) {
return argThat(other -> other.getUuid().equals(value.getUuid()));
}
}