From 158d65c6a740c7ce7d67ad267f37e475619d6ea7 Mon Sep 17 00:00:00 2001 From: Chris Eager <79161849+eager-signal@users.noreply.github.com> Date: Wed, 7 Jul 2021 11:54:22 -0500 Subject: [PATCH] Add optimistic locking to account updates --- .../auth/BaseAccountAuthenticator.java | 4 +- .../CircuitBreakerConfiguration.java | 31 +- .../controllers/AccountController.java | 107 +++---- .../controllers/DeviceController.java | 16 +- .../controllers/KeysController.java | 10 +- .../controllers/ProfileController.java | 15 +- .../textsecuregcm/push/GCMSender.java | 18 +- .../textsecuregcm/storage/Account.java | 120 +++++++- .../textsecuregcm/storage/AccountStore.java | 2 +- .../textsecuregcm/storage/Accounts.java | 38 ++- .../storage/AccountsDynamoDb.java | 51 +++- .../storage/AccountsManager.java | 98 +++++-- .../ContestedOptimisticLockException.java | 13 + ...misticLockRetryLimitExceededException.java | 10 + .../storage/PushFeedbackProcessor.java | 69 +++-- .../storage/mappers/AccountRowMapper.java | 1 + service/src/main/resources/accountsdb.xml | 6 + .../storage/AccountsDynamoDbTest.java | 45 ++- ...ConcurrentModificationIntegrationTest.java | 274 ++++++++++++++++++ .../auth/BaseAccountAuthenticatorTest.java | 164 ++++++----- .../controllers/AccountControllerTest.java | 18 +- .../controllers/DeviceControllerTest.java | 6 +- .../tests/controllers/KeysControllerTest.java | 21 +- .../controllers/ProfileControllerTest.java | 109 +++---- .../tests/push/GCMSenderTest.java | 11 +- .../tests/storage/AccountTest.java | 98 ++++--- .../tests/storage/AccountsManagerTest.java | 221 ++++++++++++-- .../tests/storage/AccountsTest.java | 7 + .../storage/PushFeedbackProcessorTest.java | 65 +++-- .../tests/util/AccountsHelper.java | 148 ++++++++++ 30 files changed, 1397 insertions(+), 399 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/ContestedOptimisticLockException.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/OptimisticLockRetryLimitExceededException.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java index 4401b453d..ebf86279f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java @@ -9,7 +9,6 @@ import static com.codahale.metrics.MetricRegistry.name; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.auth.basic.BasicCredentials; -import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import java.time.Clock; @@ -118,8 +117,7 @@ public class BaseAccountAuthenticator { Metrics.summary(DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME, IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isMaster())) .record(Duration.ofMillis(todayInMillisWithOffset - device.getLastSeen()).toDays()); - device.setLastSeen(Util.todayInMillis(clock)); - accountsManager.update(account); + accountsManager.updateDevice(account, device.getId(), d -> d.setLastSeen(Util.todayInMillis(clock))); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java index 08557c4f2..fa6cbdf40 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/CircuitBreakerConfiguration.java @@ -7,15 +7,15 @@ package org.whispersystems.textsecuregcm.configuration; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; - +import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; -import java.time.Duration; - -import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig; - public class CircuitBreakerConfiguration { @JsonProperty @@ -39,6 +39,9 @@ public class CircuitBreakerConfiguration { @Min(1) private long waitDurationInOpenStateInSeconds = 10; + @JsonProperty + private List ignoredExceptions = Collections.emptyList(); + public int getFailureRateThreshold() { return failureRateThreshold; @@ -56,6 +59,18 @@ public class CircuitBreakerConfiguration { return waitDurationInOpenStateInSeconds; } + public List getIgnoredExceptions() { + return ignoredExceptions.stream() + .map(name -> { + try { + return Class.forName(name); + } catch (final ClassNotFoundException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + } + @VisibleForTesting public void setFailureRateThreshold(int failureRateThreshold) { this.failureRateThreshold = failureRateThreshold; @@ -76,9 +91,15 @@ public class CircuitBreakerConfiguration { this.waitDurationInOpenStateInSeconds = seconds; } + @VisibleForTesting + public void setIgnoredExceptions(final List ignoredExceptions) { + this.ignoredExceptions = ignoredExceptions; + } + public CircuitBreakerConfig toCircuitBreakerConfig() { return CircuitBreakerConfig.custom() .failureRateThreshold(getFailureRateThreshold()) + .ignoreExceptions(getIgnoredExceptions().toArray(new Class[0])) .ringBufferSizeInHalfOpenState(getRingBufferSizeInHalfOpenState()) .waitDurationInOpenState(Duration.ofSeconds(getWaitDurationInOpenStateInSeconds())) .ringBufferSizeInClosedState(getRingBufferSizeInClosedState()) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index b3497c51e..079c4df8a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -439,12 +439,12 @@ public class AccountController { return; } - device.setApnId(null); - device.setVoipApnId(null); - device.setGcmId(registrationId.getGcmRegistrationId()); - device.setFetchesMessages(false); - - accounts.update(account); + account = accounts.updateDevice(account, device.getId(), d -> { + d.setApnId(null); + d.setVoipApnId(null); + d.setGcmId(registrationId.getGcmRegistrationId()); + d.setFetchesMessages(false); + }); if (!wasAccountEnabled && account.isEnabled()) { directoryQueue.refreshRegisteredUser(account); @@ -457,11 +457,12 @@ public class AccountController { public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { Account account = disabledPermittedAccount.getAccount(); Device device = account.getAuthenticatedDevice().get(); - device.setGcmId(null); - device.setFetchesMessages(false); - device.setUserAgent("OWA"); - accounts.update(account); + account = accounts.updateDevice(account, device.getId(), d -> { + d.setGcmId(null); + d.setFetchesMessages(false); + d.setUserAgent("OWA"); + }); directoryQueue.refreshRegisteredUser(account); } @@ -474,11 +475,12 @@ public class AccountController { Device device = account.getAuthenticatedDevice().get(); boolean wasAccountEnabled = account.isEnabled(); - device.setApnId(registrationId.getApnRegistrationId()); - device.setVoipApnId(registrationId.getVoipRegistrationId()); - device.setGcmId(null); - device.setFetchesMessages(false); - accounts.update(account); + account = accounts.updateDevice(account, device.getId(), d -> { + d.setApnId(registrationId.getApnRegistrationId()); + d.setVoipApnId(registrationId.getVoipRegistrationId()); + d.setGcmId(null); + d.setFetchesMessages(false); + }); if (!wasAccountEnabled && account.isEnabled()) { directoryQueue.refreshRegisteredUser(account); @@ -491,15 +493,16 @@ public class AccountController { public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { Account account = disabledPermittedAccount.getAccount(); Device device = account.getAuthenticatedDevice().get(); - device.setApnId(null); - device.setFetchesMessages(false); - if (device.getId() == 1) { - device.setUserAgent("OWI"); - } else { - device.setUserAgent("OWP"); - } - accounts.update(account); + accounts.updateDevice(account, device.getId(), d -> { + d.setApnId(null); + d.setFetchesMessages(false); + if (d.getId() == 1) { + d.setUserAgent("OWI"); + } else { + d.setUserAgent("OWP"); + } + }); directoryQueue.refreshRegisteredUser(account); } @@ -509,18 +512,18 @@ public class AccountController { @Path("/registration_lock") public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) { AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock()); - account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()); - account.setPin(null); - accounts.update(account); + accounts.update(account, a -> { + a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()); + a.setPin(null); + }); } @Timed @DELETE @Path("/registration_lock") public void removeRegistrationLock(@Auth Account account) { - account.setRegistrationLock(null, null); - accounts.update(account); + accounts.update(account, a -> a.setRegistrationLock(null, null)); } @Timed @@ -531,21 +534,21 @@ public class AccountController { // TODO Remove once PIN-based reglocks have been deprecated logger.info("PIN set by User-Agent: {}", userAgent); - account.setPin(accountLock.getPin()); - account.setRegistrationLock(null, null); - - accounts.update(account); + accounts.update(account, a -> { + a.setPin(accountLock.getPin()); + a.setRegistrationLock(null, null); + }); } @Timed @DELETE @Path("/pin/") + public void removePin(@Auth Account account, @HeaderParam("User-Agent") String userAgent) { // TODO Remove once PIN-based reglocks have been deprecated logger.info("PIN removed by User-Agent: {}", userAgent); - account.setPin(null); - accounts.update(account); + accounts.update(account, a -> a.setPin(null)); } @Timed @@ -553,8 +556,8 @@ public class AccountController { @Path("/name/") public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) { Account account = disabledPermittedAccount.getAccount(); - account.getAuthenticatedDevice().get().setName(deviceName.getDeviceName()); - accounts.update(account); + Device device = account.getAuthenticatedDevice().get(); + accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName())); } @Timed @@ -572,25 +575,29 @@ public class AccountController { @Valid AccountAttributes attributes) { Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + long deviceId = account.getAuthenticatedDevice().get().getId(); - device.setFetchesMessages(attributes.getFetchesMessages()); - device.setName(attributes.getName()); - device.setLastSeen(Util.todayInMillis()); - device.setCapabilities(attributes.getCapabilities()); - device.setRegistrationId(attributes.getRegistrationId()); - device.setUserAgent(userAgent); + account = accounts.update(account, a-> { - setAccountRegistrationLockFromAttributes(account, attributes); + a.getDevice(deviceId).ifPresent(d -> { + d.setFetchesMessages(attributes.getFetchesMessages()); + d.setName(attributes.getName()); + d.setLastSeen(Util.todayInMillis()); + d.setCapabilities(attributes.getCapabilities()); + d.setRegistrationId(attributes.getRegistrationId()); + d.setUserAgent(userAgent); + }); + + setAccountRegistrationLockFromAttributes(a, attributes); + + a.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey()); + a.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess()); + a.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber()); + + }); final boolean hasDiscoverabilityChange = (account.isDiscoverableByPhoneNumber() != attributes.isDiscoverableByPhoneNumber()); - account.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey()); - account.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess()); - account.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber()); - - accounts.update(account); - if (hasDiscoverabilityChange) { directoryQueue.refreshRegisteredUser(account); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index e5b3eb2c0..a7b76c165 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -100,8 +100,7 @@ public class DeviceController { } messages.clear(account.getUuid(), deviceId); - account.removeDevice(deviceId); - accounts.update(account); + account = accounts.update(account, a -> a.removeDevice(deviceId)); directoryQueue.refreshRegisteredUser(account); // ensure any messages that came in after the first clear() are also removed messages.clear(account.getUuid(), deviceId); @@ -192,15 +191,16 @@ public class DeviceController { device.setName(accountAttributes.getName()); device.setAuthenticationCredentials(new AuthenticationCredentials(password)); device.setFetchesMessages(accountAttributes.getFetchesMessages()); - device.setId(account.get().getNextDeviceId()); device.setRegistrationId(accountAttributes.getRegistrationId()); device.setLastSeen(Util.todayInMillis()); device.setCreated(System.currentTimeMillis()); device.setCapabilities(accountAttributes.getCapabilities()); - account.get().addDevice(device); - messages.clear(account.get().getUuid(), device.getId()); - accounts.update(account.get()); + accounts.update(account.get(), a -> { + device.setId(account.get().getNextDeviceId()); + messages.clear(account.get().getUuid(), device.getId()); + a.addDevice(device); + });; pendingDevices.remove(number); @@ -224,8 +224,8 @@ public class DeviceController { @Path("/capabilities") public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) { assert(account.getAuthenticatedDevice().isPresent()); - account.getAuthenticatedDevice().get().setCapabilities(capabilities); - accounts.update(account); + final long deviceId = account.getAuthenticatedDevice().get().getId(); + accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities)); } @VisibleForTesting protected VerificationCode generateVerificationCode() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 2015c7150..04030c2a0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -104,17 +104,18 @@ public class KeysController { boolean updateAccount = false; if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) { - device.setSignedPreKey(preKeys.getSignedPreKey()); updateAccount = true; } if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) { - account.setIdentityKey(preKeys.getIdentityKey()); updateAccount = true; } if (updateAccount) { - accounts.update(account); + account = accounts.update(account, a -> { + a.getDevice(device.getId()).ifPresent(d -> d.setSignedPreKey(preKeys.getSignedPreKey())); + a.setIdentityKey(preKeys.getIdentityKey()); + }); if (!wasAccountEnabled && account.isEnabled()) { directoryQueue.refreshRegisteredUser(account); @@ -200,8 +201,7 @@ public class KeysController { Device device = account.getAuthenticatedDevice().get(); boolean wasAccountEnabled = account.isEnabled(); - device.setSignedPreKey(signedPreKey); - accounts.update(account); + account = accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey)); if (!wasAccountEnabled && account.isEnabled()) { directoryQueue.refreshRegisteredUser(account); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index f18de2fb6..beb15c78b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -156,10 +156,11 @@ public class ProfileController { response = Optional.of(generateAvatarUploadForm(avatar)); } - account.setProfileName(request.getName()); - account.setAvatar(avatar); - account.setCurrentProfileVersion(request.getVersion()); - accountsManager.update(account); + accountsManager.update(account, a -> { + a.setProfileName(request.getName()); + a.setAvatar(avatar); + a.setCurrentProfileVersion(request.getVersion()); + }); if (response.isPresent()) return Response.ok(response).build(); else return Response.ok().build(); @@ -317,8 +318,7 @@ public class ProfileController { @Produces(MediaType.APPLICATION_JSON) @Path("/name/{name}") public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional name) { - account.setProfileName(name.orElse(null)); - accountsManager.update(account); + accountsManager.update(account, a -> a.setProfileName(name.orElse(null))); } @Deprecated @@ -382,8 +382,7 @@ public class ProfileController { .build()); } - account.setAvatar(objectName); - accountsManager.update(account); + accountsManager.update(account, a -> a.setAvatar(objectName)); return profileAvatarUploadAttributes; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java index 6d0876376..961d9e3c2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/GCMSender.java @@ -110,8 +110,8 @@ public class GCMSender { Device device = account.get().getDevice(message.getDeviceId()).get(); if (device.getUninstalledFeedbackTimestamp() == 0) { - device.setUninstalledFeedbackTimestamp(Util.todayInMillis()); - accountsManager.update(account.get()); + accountsManager.updateDevice(account.get(), message.getDeviceId(), d -> + d.setUninstalledFeedbackTimestamp(Util.todayInMillis())); } } @@ -122,15 +122,11 @@ public class GCMSender { logger.warn(String.format("Actually received 'CanonicalRegistrationId' ::: (canonical=%s), (original=%s)", result.getCanonicalRegistrationId(), message.getGcmId())); - Optional account = getAccountForEvent(message); - - if (account.isPresent()) { - //noinspection OptionalGetWithoutIsPresent - Device device = account.get().getDevice(message.getDeviceId()).get(); - device.setGcmId(result.getCanonicalRegistrationId()); - - accountsManager.update(account.get()); - } + getAccountForEvent(message).ifPresent(account -> + accountsManager.updateDevice( + account, + message.getDeviceId(), + d -> d.setGcmId(result.getCanonicalRegistrationId()))); canonical.mark(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index fdebdf5ad..a437fe5e5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -14,11 +14,16 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import javax.security.auth.Subject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; public class Account implements Principal { + @JsonIgnore + private static final Logger logger = LoggerFactory.getLogger(Account.class); + @JsonIgnore private UUID uuid; @@ -58,12 +63,15 @@ public class Account implements Principal { @JsonProperty("inCds") private boolean discoverableByPhoneNumber = true; - @JsonProperty("_ddbV") - private int dynamoDbMigrationVersion; - @JsonIgnore private Device authenticatedDevice; + @JsonProperty + private int version; + + @JsonIgnore + private boolean stale; + public Account() {} @VisibleForTesting @@ -75,47 +83,68 @@ public class Account implements Principal { } public Optional getAuthenticatedDevice() { + requireNotStale(); + return Optional.ofNullable(authenticatedDevice); } public void setAuthenticatedDevice(Device device) { + requireNotStale(); + this.authenticatedDevice = device; } public UUID getUuid() { + // this is the one method that may be called on a stale account return uuid; } public void setUuid(UUID uuid) { + requireNotStale(); + this.uuid = uuid; } public void setNumber(String number) { + requireNotStale(); + this.number = number; } public String getNumber() { + requireNotStale(); + return number; } public void addDevice(Device device) { + requireNotStale(); + this.devices.remove(device); this.devices.add(device); } public void removeDevice(long deviceId) { + requireNotStale(); + this.devices.remove(new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "NA", 0, null)); } public Set getDevices() { + requireNotStale(); + return devices; } public Optional getMasterDevice() { + requireNotStale(); + return getDevice(Device.MASTER_ID); } public Optional getDevice(long deviceId) { + requireNotStale(); + for (Device device : devices) { if (device.getId() == deviceId) { return Optional.of(device); @@ -126,42 +155,58 @@ public class Account implements Principal { } public boolean isGroupsV2Supported() { + requireNotStale(); + return devices.stream() .filter(Device::isEnabled) .allMatch(Device::isGroupsV2Supported); } public boolean isStorageSupported() { + requireNotStale(); + return devices.stream().anyMatch(device -> device.getCapabilities() != null && device.getCapabilities().isStorage()); } public boolean isTransferSupported() { + requireNotStale(); + return getMasterDevice().map(Device::getCapabilities).map(Device.DeviceCapabilities::isTransfer).orElse(false); } public boolean isGv1MigrationSupported() { + requireNotStale(); + return devices.stream() .filter(Device::isEnabled) .allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isGv1Migration()); } public boolean isSenderKeySupported() { + requireNotStale(); + return devices.stream() .filter(Device::isEnabled) .allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isSenderKey()); } public boolean isAnnouncementGroupSupported() { + requireNotStale(); + return devices.stream() .filter(Device::isEnabled) .allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isAnnouncementGroup()); } public boolean isEnabled() { + requireNotStale(); + return getMasterDevice().map(Device::isEnabled).orElse(false); } public long getNextDeviceId() { + requireNotStale(); + long highestDevice = Device.MASTER_ID; for (Device device : devices) { @@ -176,6 +221,8 @@ public class Account implements Principal { } public int getEnabledDeviceCount() { + requireNotStale(); + int count = 0; for (Device device : devices) { @@ -186,22 +233,32 @@ public class Account implements Principal { } public boolean isRateLimited() { + requireNotStale(); + return true; } public Optional getRelay() { + requireNotStale(); + return Optional.empty(); } public void setIdentityKey(String identityKey) { + requireNotStale(); + this.identityKey = identityKey; } public String getIdentityKey() { + requireNotStale(); + return identityKey; } public long getLastSeen() { + requireNotStale(); + long lastSeen = 0; for (Device device : devices) { @@ -214,78 +271,127 @@ public class Account implements Principal { } public Optional getCurrentProfileVersion() { + requireNotStale(); + return Optional.ofNullable(currentProfileVersion); } public void setCurrentProfileVersion(String currentProfileVersion) { + requireNotStale(); + this.currentProfileVersion = currentProfileVersion; } public String getProfileName() { + requireNotStale(); + return name; } public void setProfileName(String name) { + requireNotStale(); + this.name = name; } public String getAvatar() { + requireNotStale(); + return avatar; } public void setAvatar(String avatar) { + requireNotStale(); + this.avatar = avatar; } public void setPin(String pin) { + requireNotStale(); + this.pin = pin; } public void setRegistrationLock(String registrationLock, String registrationLockSalt) { + requireNotStale(); + this.registrationLock = registrationLock; this.registrationLockSalt = registrationLockSalt; } public StoredRegistrationLock getRegistrationLock() { + requireNotStale(); + return new StoredRegistrationLock(Optional.ofNullable(registrationLock), Optional.ofNullable(registrationLockSalt), Optional.ofNullable(pin), getLastSeen()); } public Optional getUnidentifiedAccessKey() { + requireNotStale(); + return Optional.ofNullable(unidentifiedAccessKey); } public void setUnidentifiedAccessKey(byte[] unidentifiedAccessKey) { + requireNotStale(); + this.unidentifiedAccessKey = unidentifiedAccessKey; } public boolean isUnrestrictedUnidentifiedAccess() { + requireNotStale(); + return unrestrictedUnidentifiedAccess; } public void setUnrestrictedUnidentifiedAccess(boolean unrestrictedUnidentifiedAccess) { + requireNotStale(); + this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess; } public boolean isFor(AmbiguousIdentifier identifier) { + requireNotStale(); + if (identifier.hasUuid()) return identifier.getUuid().equals(uuid); else if (identifier.hasNumber()) return identifier.getNumber().equals(number); else throw new AssertionError(); } public boolean isDiscoverableByPhoneNumber() { + requireNotStale(); + return this.discoverableByPhoneNumber; } public void setDiscoverableByPhoneNumber(final boolean discoverableByPhoneNumber) { + requireNotStale(); + this.discoverableByPhoneNumber = discoverableByPhoneNumber; } - public int getDynamoDbMigrationVersion() { - return dynamoDbMigrationVersion; + public int getVersion() { + requireNotStale(); + + return version; } - public void setDynamoDbMigrationVersion(int dynamoDbMigrationVersion) { - this.dynamoDbMigrationVersion = dynamoDbMigrationVersion; + public void setVersion(int version) { + requireNotStale(); + + this.version = version; + } + + public void markStale() { + stale = true; + } + + private void requireNotStale() { + assert !stale; + + //noinspection ConstantConditions + if (stale) { + logger.error("Accessor called on stale account", new RuntimeException()); + } } // Principal implementation diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountStore.java index 8b9e733c2..065c97bf6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountStore.java @@ -7,7 +7,7 @@ public interface AccountStore { boolean create(Account account); - void update(Account account); + void update(Account account) throws ContestedOptimisticLockException; Optional get(String number); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index a8a8a1fa8..ebbcf8ec5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -13,6 +13,7 @@ import com.codahale.metrics.Timer.Context; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import org.jdbi.v3.core.transaction.TransactionIsolationLevel; @@ -22,10 +23,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; public class Accounts implements AccountStore { - public static final String ID = "id"; - public static final String UID = "uuid"; + public static final String ID = "id"; + public static final String UID = "uuid"; public static final String NUMBER = "number"; - public static final String DATA = "data"; + public static final String DATA = "data"; + public static final String VERSION = "version"; private static final ObjectMapper mapper = SystemMapper.getMapper(); @@ -50,15 +52,19 @@ public class Accounts implements AccountStore { public boolean create(Account account) { return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> { try (Timer.Context ignored = createTimer.time()) { - UUID uuid = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET data = EXCLUDED.data RETURNING uuid") - .bind("number", account.getNumber()) - .bind("uuid", account.getUuid()) - .bind("data", mapper.writeValueAsString(account)) - .mapTo(UUID.class) - .findOnly(); + final Map resultMap = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET " + DATA + " = EXCLUDED.data, " + VERSION + " = accounts.version + 1 RETURNING uuid, version") + .bind("number", account.getNumber()) + .bind("uuid", account.getUuid()) + .bind("data", mapper.writeValueAsString(account)) + .mapToMap() + .findOnly(); + + final UUID uuid = (UUID) resultMap.get(UID); + final int version = (int) resultMap.get(VERSION); boolean isNew = uuid.equals(account.getUuid()); account.setUuid(uuid); + account.setVersion(version); return isNew; } catch (JsonProcessingException e) { throw new IllegalArgumentException(e); @@ -67,13 +73,23 @@ public class Accounts implements AccountStore { } @Override - public void update(Account account) { + public void update(Account account) throws ContestedOptimisticLockException { database.use(jdbi -> jdbi.useHandle(handle -> { try (Timer.Context ignored = updateTimer.time()) { - handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + UID + " = :uuid") + final int newVersion = account.getVersion() + 1; + int rowsModified = handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json), " + VERSION + " = :newVersion WHERE " + UID + " = :uuid AND " + VERSION + " = :version") .bind("uuid", account.getUuid()) .bind("data", mapper.writeValueAsString(account)) + .bind("version", account.getVersion()) + .bind("newVersion", newVersion) .execute(); + + if (rowsModified == 0) { + throw new ContestedOptimisticLockException(); + } + + account.setVersion(newVersion); + } catch (JsonProcessingException e) { throw new IllegalArgumentException(e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDb.java index b43fddfd5..83a89bf2f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDb.java @@ -30,16 +30,20 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.CancellationReason; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.Put; +import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; +import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse; public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountStore { @@ -49,8 +53,8 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt static final String ATTR_ACCOUNT_E164 = "P"; // account, serialized to JSON static final String ATTR_ACCOUNT_DATA = "D"; - - static final String ATTR_MIGRATION_VERSION = "V"; + // internal version for optimistic locking + static final String ATTR_VERSION = "V"; private final DynamoDbClient client; private final DynamoDbAsyncClient asyncClient; @@ -122,11 +126,19 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt ByteBuffer actualAccountUuid = phoneNumberConstraintCancellationReason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer(); account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid)); + final int version = get(account.getUuid()).get().getVersion(); + account.setVersion(version); + update(account); return false; } + if ("TransactionConflict".equals(accountCancellationReason.code())) { + // this should only happen during concurrent update()s for an account migration + throw new ContestedOptimisticLockException(); + } + // this shouldn’t happen throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e)); } @@ -146,7 +158,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid), ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()), ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), - ATTR_MIGRATION_VERSION, AttributeValues.fromInt(account.getDynamoDbMigrationVersion()))) + ATTR_VERSION, AttributeValues.fromInt(account.getVersion()))) .build()) .build(); } @@ -172,28 +184,44 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt } @Override - public void update(Account account) { + public void update(Account account) throws ContestedOptimisticLockException { UPDATE_TIMER.record(() -> { UpdateItemRequest updateItemRequest; try { updateItemRequest = UpdateItemRequest.builder() .tableName(accountsTableName) .key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) - .updateExpression("SET #data = :data, #version = :version") - .conditionExpression("attribute_exists(#number)") + .updateExpression("SET #data = :data ADD #version :version_increment") + .conditionExpression("attribute_exists(#number) AND #version = :version") .expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164, "#data", ATTR_ACCOUNT_DATA, - "#version", ATTR_MIGRATION_VERSION)) + "#version", ATTR_VERSION)) .expressionAttributeValues(Map.of( ":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)), - ":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion()))) + ":version", AttributeValues.fromInt(account.getVersion()), + ":version_increment", AttributeValues.fromInt(1))) + .returnValues(ReturnValue.UPDATED_NEW) .build(); } catch (JsonProcessingException e) { throw new IllegalArgumentException(e); } - client.updateItem(updateItemRequest); + try { + UpdateItemResponse response = client.updateItem(updateItemRequest); + + account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1)); + } catch (final TransactionConflictException e) { + + throw new ContestedOptimisticLockException(); + + } catch (final ConditionalCheckFailedException e) { + + // the exception doesn’t give details about which condition failed, + // but we can infer it was an optimistic locking failure if the UUID is known + throw get(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e; + } + }); } @@ -343,9 +371,9 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt .conditionExpression("attribute_not_exists(#uuid) OR (attribute_exists(#uuid) AND #version < :version)") .expressionAttributeNames(Map.of( "#uuid", KEY_ACCOUNT_UUID, - "#version", ATTR_MIGRATION_VERSION)) + "#version", ATTR_VERSION)) .expressionAttributeValues(Map.of( - ":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))); + ":version", AttributeValues.fromInt(account.getVersion())))); final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() .transactItems(phoneNumberConstraintPut, accountPut).build(); @@ -395,6 +423,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class); account.setNumber(item.get(ATTR_ACCOUNT_E164).s()); account.setUuid(UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer())); + account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n())); return account; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 0d1959037..e25d1e245 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -26,6 +26,8 @@ import java.util.UUID; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import net.logstash.logback.argument.StructuredArguments; import org.apache.commons.lang3.StringUtils; @@ -40,7 +42,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; -import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; public class AccountsManager { @@ -119,7 +120,6 @@ public class AccountsManager { this.mapper = SystemMapper.getMapper(); this.migrationComparisonMapper = mapper.copy(); - migrationComparisonMapper.addMixIn(Account.class, AccountComparisonMixin.class); migrationComparisonMapper.addMixIn(Device.class, DeviceComparisonMixin.class); this.dynamicConfigurationManager = dynamicConfigurationManager; @@ -169,25 +169,86 @@ public class AccountsManager { } } - public void update(Account account) { + public Account update(Account account, Consumer updater) { + + final Account updatedAccount; + try (Timer.Context ignored = updateTimer.time()) { - account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1); - redisSet(account); - databaseUpdate(account); + updater.accept(account); + + { + // optimistically increment version + final int originalVersion = account.getVersion(); + account.setVersion(originalVersion + 1); + redisSet(account); + account.setVersion(originalVersion); + } + + final UUID uuid = account.getUuid(); + + updatedAccount = updateWithRetries(account, updater, this::databaseUpdate, () -> databaseGet(uuid).get()); if (dynamoWriteEnabled()) { runSafelyAndRecordMetrics(() -> { - try { - dynamoUpdate(account); - } catch (final ConditionalCheckFailedException e) { - dynamoCreate(account); + + final Optional dynamoAccount = dynamoGet(uuid); + if (dynamoAccount.isPresent()) { + updater.accept(dynamoAccount.get()); + Account dynamoUpdatedAccount = updateWithRetries(dynamoAccount.get(), + updater, + this::dynamoUpdate, + () -> dynamoGet(uuid).get()); + + return Optional.of(dynamoUpdatedAccount); } - return true; - }, Optional.of(account.getUuid()), true, - (databaseSuccess, dynamoSuccess) -> Optional.empty(), // both values are always true + + return Optional.empty(); + }, Optional.of(uuid), Optional.of(updatedAccount), + this::compareAccounts, "update"); } + + // set the cache again, so that all updates are coalesced + redisSet(updatedAccount); } + + return updatedAccount; + } + + private Account updateWithRetries(Account account, Consumer updater, Consumer persister, Supplier retriever) { + + final int maxTries = 10; + int tries = 0; + + while (tries < maxTries) { + + try { + persister.accept(account); + + final Account updatedAccount; + try { + updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class); + updatedAccount.setUuid(account.getUuid()); + } catch (final IOException e) { + // this should really, truly, never happen + throw new IllegalArgumentException(e); + } + + account.markStale(); + + return updatedAccount; + } catch (final ContestedOptimisticLockException e) { + tries++; + account = retriever.get(); + updater.accept(account); + } + } + + throw new OptimisticLockRetryLimitExceededException(); + } + + public Account updateDevice(Account account, long deviceId, Consumer deviceUpdater) { + return update(account, a -> a.getDevice(deviceId).ifPresent(deviceUpdater)); } public Optional get(AmbiguousIdentifier identifier) { @@ -445,6 +506,10 @@ public class AccountsManager { return Optional.of("number"); } + if (databaseAccount.getVersion() != dynamoAccount.getVersion()) { + return Optional.of("version"); + } + if (!Objects.equals(databaseAccount.getIdentityKey(), dynamoAccount.getIdentityKey())) { return Optional.of("identityKey"); } @@ -566,13 +631,6 @@ public class AccountsManager { .collect(Collectors.joining(" -> ")); } - private static abstract class AccountComparisonMixin extends Account { - - @JsonIgnore - private int dynamoDbMigrationVersion; - - } - private static abstract class DeviceComparisonMixin extends Device { @JsonIgnore diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ContestedOptimisticLockException.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ContestedOptimisticLockException.java new file mode 100644 index 000000000..7b961d89c --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ContestedOptimisticLockException.java @@ -0,0 +1,13 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +public class ContestedOptimisticLockException extends RuntimeException { + + public ContestedOptimisticLockException() { + super(null, null, true, false); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/OptimisticLockRetryLimitExceededException.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/OptimisticLockRetryLimitExceededException.java new file mode 100644 index 000000000..1e608ed09 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/OptimisticLockRetryLimitExceededException.java @@ -0,0 +1,10 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +public class OptimisticLockRetryLimitExceededException extends RuntimeException { + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java index 60b311e83..c522999d4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PushFeedbackProcessor.java @@ -5,20 +5,20 @@ package org.whispersystems.textsecuregcm.storage; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; -import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; -import org.whispersystems.textsecuregcm.util.Constants; -import org.whispersystems.textsecuregcm.util.Util; - import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; - -import static com.codahale.metrics.MetricRegistry.name; +import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; +import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Util; public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener { @@ -47,36 +47,42 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener { for (Account account : chunkAccounts) { boolean update = false; - for (Device device : account.getDevices()) { - if (device.getUninstalledFeedbackTimestamp() != 0 && - device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis()) - { - if (device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis()) { - if (!Util.isEmpty(device.getApnId())) { - if (device.getId() == 1) { - device.setUserAgent("OWI"); - } else { - device.setUserAgent("OWP"); - } - } else if (!Util.isEmpty(device.getGcmId())) { - device.setUserAgent("OWA"); - } - device.setGcmId(null); - device.setApnId(null); - device.setVoipApnId(null); - device.setFetchesMessages(false); + final Set devices = account.getDevices(); + for (Device device : devices) { + if (deviceNeedsUpdate(device)) { + if (deviceExpired(device)) { expired.mark(); } else { - device.setUninstalledFeedbackTimestamp(0); recovered.mark(); } - update = true; } } if (update) { - accountsManager.update(account); + account = accountsManager.update(account, a -> { + for (Device device: a.getDevices()) { + if (deviceNeedsUpdate(device)) { + if (deviceExpired(device)) { + if (!Util.isEmpty(device.getApnId())) { + if (device.getId() == 1) { + device.setUserAgent("OWI"); + } else { + device.setUserAgent("OWP"); + } + } else if (!Util.isEmpty(device.getGcmId())) { + device.setUserAgent("OWA"); + } + device.setGcmId(null); + device.setApnId(null); + device.setVoipApnId(null); + device.setFetchesMessages(false); + } else { + device.setUninstalledFeedbackTimestamp(0); + } + } + } + }); directoryUpdateAccounts.add(account); } } @@ -85,4 +91,13 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener { directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts); } } + + private boolean deviceNeedsUpdate(final Device device) { + return device.getUninstalledFeedbackTimestamp() != 0 && + device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis(); + } + + private boolean deviceExpired(final Device device) { + return device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis(); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java index 450699bfd..33c3a535e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/AccountRowMapper.java @@ -27,6 +27,7 @@ public class AccountRowMapper implements RowMapper { Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class); account.setNumber(resultSet.getString(Accounts.NUMBER)); account.setUuid(UUID.fromString(resultSet.getString(Accounts.UID))); + account.setVersion(resultSet.getInt(Accounts.VERSION)); return account; } catch (IOException e) { throw new SQLException(e); diff --git a/service/src/main/resources/accountsdb.xml b/service/src/main/resources/accountsdb.xml index b7e2acba6..a892e9cfc 100644 --- a/service/src/main/resources/accountsdb.xml +++ b/service/src/main/resources/accountsdb.xml @@ -375,4 +375,10 @@ + + + + + + diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDbTest.java index 02d3b740c..5cb0b51b9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsDynamoDbTest.java @@ -50,6 +50,7 @@ import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanResponse; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; +import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; class AccountsDynamoDbTest { @@ -211,6 +212,10 @@ class AccountsDynamoDbTest { verifyStoredState("+14151112222", account.getUuid(), account); + account.setProfileName("name"); + + accountsDynamoDb.update(account); + UUID secondUuid = UUID.randomUUID(); device = generateDevice(1); @@ -252,13 +257,44 @@ class AccountsDynamoDbTest { assertThatThrownBy(() -> accountsDynamoDb.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class); - account.setDynamoDbMigrationVersion(5); + account.setProfileName("name"); + + accountsDynamoDb.update(account); + + assertThat(account.getVersion()).isEqualTo(2); + + verifyStoredState("+14151112222", account.getUuid(), account); + + account.setVersion(1); + + assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class); + + account.setVersion(2); + account.setProfileName("name2"); accountsDynamoDb.update(account); verifyStoredState("+14151112222", account.getUuid(), account); } + @Test + void testUpdateWithMockTransactionConflictException() { + + final DynamoDbClient dynamoDbClient = mock(DynamoDbClient.class); + accountsDynamoDb = new AccountsDynamoDb(dynamoDbClient, mock(DynamoDbAsyncClient.class), + new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()), + dynamoDbExtension.getTableName(), NUMBERS_TABLE_NAME, mock(MigrationDeletedAccounts.class), + mock(MigrationRetryAccounts.class)); + + when(dynamoDbClient.updateItem(any(UpdateItemRequest.class))) + .thenThrow(TransactionConflictException.class); + + Device device = generateDevice (1 ); + Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device)); + + assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class); + } + @Test void testRetrieveFrom() { List users = new ArrayList<>(); @@ -463,7 +499,7 @@ class AccountsDynamoDbTest { assertThat(migrated).isFalse(); verifyStoredState("+14151112222", firstUuid, account); - account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1); + account.setVersion(account.getVersion() + 1); migrated = accountsDynamoDb.migrate(account).get(); @@ -504,8 +540,8 @@ class AccountsDynamoDbTest { String data = new String(get.item().get(AccountsDynamoDb.ATTR_ACCOUNT_DATA).b().asByteArray(), StandardCharsets.UTF_8); assertThat(data).isNotEmpty(); - assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_MIGRATION_VERSION, -1)) - .isEqualTo(expecting.getDynamoDbMigrationVersion()); + assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_VERSION, -1)) + .isEqualTo(expecting.getVersion()); Account result = AccountsDynamoDb.fromItem(get.item()); verifyStoredState(number, uuid, result, expecting); @@ -518,6 +554,7 @@ class AccountsDynamoDbTest { assertThat(result.getNumber()).isEqualTo(number); assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen()); assertThat(result.getUuid()).isEqualTo(uuid); + assertThat(result.getVersion()).isEqualTo(expecting.getVersion()); assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue(); for (Device expectingDevice : expecting.getDevices()) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java new file mode 100644 index 000000000..c381a60fc --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -0,0 +1,274 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.opentable.db.postgres.embedded.LiquibasePreparer; +import com.opentable.db.postgres.junit5.EmbeddedPostgresExtension; +import com.opentable.db.postgres.junit5.PreparedDbExtension; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; +import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; +import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; +import org.whispersystems.textsecuregcm.tests.util.JsonHelpers; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import org.whispersystems.textsecuregcm.util.Pair; +import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; + +class AccountsManagerConcurrentModificationIntegrationTest { + + @RegisterExtension + static PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml")); + + private static final String ACCOUNTS_TABLE_NAME = "accounts_test"; + private static final String NUMBERS_TABLE_NAME = "numbers_test"; + + @RegisterExtension + static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder() + .tableName(ACCOUNTS_TABLE_NAME) + .hashKey(AccountsDynamoDb.KEY_ACCOUNT_UUID) + .attributeDefinition(AttributeDefinition.builder() + .attributeName(AccountsDynamoDb.KEY_ACCOUNT_UUID) + .attributeType(ScalarAttributeType.B) + .build()) + .build(); + + private Accounts accounts; + + private AccountsDynamoDb accountsDynamoDb; + + private AccountsManager accountsManager; + + private RedisAdvancedClusterCommands commands; + + private Executor mutationExecutor = new ThreadPoolExecutor(20, 20, 5, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20)); + + @BeforeEach + void setup() { + + { + CreateTableRequest createNumbersTableRequest = CreateTableRequest.builder() + .tableName(NUMBERS_TABLE_NAME) + .keySchema(KeySchemaElement.builder() + .attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164) + .keyType(KeyType.HASH) + .build()) + .attributeDefinitions(AttributeDefinition.builder() + .attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164) + .attributeType(ScalarAttributeType.S) + .build()) + .provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT) + .build(); + + dynamoDbExtension.getDynamoDbClient().createTable(createNumbersTableRequest); + } + + accountsDynamoDb = new AccountsDynamoDb( + dynamoDbExtension.getDynamoDbClient(), + dynamoDbExtension.getDynamoDbAsyncClient(), + new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()), + dynamoDbExtension.getTableName(), + NUMBERS_TABLE_NAME, + mock(MigrationDeletedAccounts.class), + mock(MigrationRetryAccounts.class)); + + { + final CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration(); + circuitBreakerConfiguration.setIgnoredExceptions(List.of("org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException")); + FaultTolerantDatabase faultTolerantDatabase = new FaultTolerantDatabase("accountsTest", + Jdbi.create(db.getTestDatabase()), + circuitBreakerConfiguration); + + accounts = new Accounts(faultTolerantDatabase); + } + + { + final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + + DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + + final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + + final DynamicAccountsDynamoDbMigrationConfiguration config = dynamicConfiguration + .getAccountsDynamoDbMigrationConfiguration(); + + config.setDeleteEnabled(true); + config.setReadEnabled(true); + config.setWriteEnabled(true); + + when(experimentEnrollmentManager.isEnrolled(any(UUID.class), anyString())).thenReturn(true); + + commands = mock(RedisAdvancedClusterCommands.class); + + accountsManager = new AccountsManager( + accounts, + accountsDynamoDb, + RedisClusterHelper.buildMockRedisCluster(commands), + mock(DeletedAccounts.class), + mock(DirectoryQueue.class), + mock(KeysDynamoDb.class), + mock(MessagesManager.class), + mock(UsernamesManager.class), + mock(ProfilesManager.class), + mock(SecureStorageClient.class), + mock(SecureBackupClient.class), + experimentEnrollmentManager, + dynamicConfigurationManager); + } + } + + @Test + void testConcurrentUpdate() throws IOException { + + final UUID uuid = UUID.randomUUID(); + + accountsManager.create(generateAccount("+14155551212", uuid)); + + final String profileName = "name"; + final String avatar = "avatar"; + final boolean discoverableByPhoneNumber = false; + final String currentProfileVersion = "cpv"; + final String identityKey = "ikey"; + final byte[] unidentifiedAccessKey = new byte[]{1}; + final String pin = "1234"; + final String registrationLock = "reglock"; + final AuthenticationCredentials credentials = new AuthenticationCredentials(registrationLock); + final boolean unrestrictedUnidentifiedAccess = true; + final long lastSeen = Instant.now().getEpochSecond(); + + CompletableFuture.allOf( + modifyAccount(uuid, account -> account.setProfileName(profileName)), + modifyAccount(uuid, account -> account.setAvatar(avatar)), + modifyAccount(uuid, account -> account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber)), + modifyAccount(uuid, account -> account.setCurrentProfileVersion(currentProfileVersion)), + modifyAccount(uuid, account -> account.setIdentityKey(identityKey)), + modifyAccount(uuid, account -> account.setUnidentifiedAccessKey(unidentifiedAccessKey)), + modifyAccount(uuid, account -> account.setPin(pin)), + modifyAccount(uuid, account -> account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt())), + modifyAccount(uuid, account -> account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess)), + modifyDevice(uuid, Device.MASTER_ID, device-> device.setLastSeen(lastSeen)), + modifyDevice(uuid, Device.MASTER_ID, device-> device.setName("deviceName")) + ).join(); + + final Account managerAccount = accountsManager.get(uuid).get(); + final Account dbAccount = accounts.get(uuid).get(); + final Account dynamoAccount = accountsDynamoDb.get(uuid).get(); + + final Account redisAccount = getLastAccountFromRedisMock(commands); + + Stream.of( + new Pair<>("manager", managerAccount), + new Pair<>("db", dbAccount), + new Pair<>("dynamo", dynamoAccount), + new Pair<>("redis", redisAccount) + ).forEach(pair -> + verifyAccount(pair.first(), pair.second(), profileName, avatar, discoverableByPhoneNumber, + currentProfileVersion, identityKey, unidentifiedAccessKey, pin, registrationLock, + unrestrictedUnidentifiedAccess, lastSeen) + ); + } + + private Account getLastAccountFromRedisMock(RedisAdvancedClusterCommands commands) throws IOException { + ArgumentCaptor redisSetArgumentCapture = ArgumentCaptor.forClass(String.class); + + verify(commands, atLeast(20)).set(anyString(), redisSetArgumentCapture.capture()); + + return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class); + } + + private void verifyAccount(final String name, final Account account, final String profileName, final String avatar, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final String identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAcces, final long lastSeen) { + + assertAll(name, + () -> assertEquals(profileName, account.getProfileName()), + () -> assertEquals(avatar, account.getAvatar()), + () -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()), + () -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().get()), + () -> assertEquals(identityKey, account.getIdentityKey()), + () -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().get()), + () -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock, pin)), + () -> assertEquals(unrestrictedUnidentifiedAcces, account.isUnrestrictedUnidentifiedAccess()) + ); + } + + private CompletableFuture modifyAccount(final UUID uuid, final Consumer accountMutation) { + + return CompletableFuture.runAsync(() -> { + final Account account = accountsManager.get(uuid).get(); + accountsManager.update(account, accountMutation); + }, mutationExecutor); + } + + private CompletableFuture modifyDevice(final UUID uuid, final long deviceId, final Consumer deviceMutation) { + + return CompletableFuture.runAsync(() -> { + final Account account = accountsManager.get(uuid).get(); + accountsManager.updateDevice(account, deviceId, deviceMutation); + }, mutationExecutor); + } + + private Account generateAccount(String number, UUID uuid) { + Device device = generateDevice(1); + return generateAccount(number, uuid, Collections.singleton(device)); + } + + private Account generateAccount(String number, UUID uuid, Set devices) { + byte[] unidentifiedAccessKey = new byte[16]; + Random random = new Random(System.currentTimeMillis()); + Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255)); + + return new Account(number, uuid, devices, unidentifiedAccessKey); + } + + private Device generateDevice(long id) { + Random random = new Random(System.currentTimeMillis()); + SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt()); + return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(), + "testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt() , 0, new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), + random.nextBoolean(), random.nextBoolean())); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/BaseAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/BaseAccountAuthenticatorTest.java index b5ffd4009..612f5a4ce 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/BaseAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/auth/BaseAccountAuthenticatorTest.java @@ -5,104 +5,120 @@ package org.whispersystems.textsecuregcm.tests.auth; -import org.junit.Before; -import org.junit.Test; -import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.tests.util.AuthHelper; - -import java.time.Clock; -import java.time.Instant; -import java.util.Random; -import java.util.Set; - import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class BaseAccountAuthenticatorTest { +import java.time.Clock; +import java.time.Instant; +import java.util.Random; +import java.util.Set; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; - private final Random random = new Random(867_5309L); - private final long today = 1590451200000L; - private final long yesterday = today - 86_400_000L; - private final long oldTime = yesterday - 86_400_000L; - private final long currentTime = today + 68_000_000L; +class BaseAccountAuthenticatorTest { - private AccountsManager accountsManager; - private BaseAccountAuthenticator baseAccountAuthenticator; - private Clock clock; - private Account acct1; - private Account acct2; - private Account oldAccount; + private final Random random = new Random(867_5309L); + private final long today = 1590451200000L; + private final long yesterday = today - 86_400_000L; + private final long oldTime = yesterday - 86_400_000L; + private final long currentTime = today + 68_000_000L; - @Before - public void setup() { - accountsManager = mock(AccountsManager.class); - clock = mock(Clock.class); - baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock); + private AccountsManager accountsManager; + private BaseAccountAuthenticator baseAccountAuthenticator; + private Clock clock; + private Account acct1; + private Account acct2; + private Account oldAccount; - acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, - null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null); - acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, - null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null); - oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, - null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null); - } + @BeforeEach + void setup() { + accountsManager = mock(AccountsManager.class); + clock = mock(Clock.class); + baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock); - @Test - public void testUpdateLastSeenMiddleOfDay() { - when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime)); + acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, + null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null); + acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, + null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null); + oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null, + null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null); - baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get()); - baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get()); + AccountsHelper.setupMockUpdate(accountsManager); + } - verify(accountsManager, never()).update(acct1); - verify(accountsManager).update(acct2); + @Test + void testUpdateLastSeenMiddleOfDay() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime)); - assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday); - assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today); - } + final Device device1 = acct1.getDevices().stream().findFirst().get(); + final Device device2 = acct2.getDevices().stream().findFirst().get(); - @Test - public void testUpdateLastSeenStartOfDay() { - when(clock.instant()).thenReturn(Instant.ofEpochMilli(today)); + baseAccountAuthenticator.updateLastSeen(acct1, device1); + baseAccountAuthenticator.updateLastSeen(acct2, device2); - baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get()); - baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get()); + verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any()); + verify(accountsManager).updateDevice(eq(acct2), anyLong(), any()); - verify(accountsManager, never()).update(acct1); - verify(accountsManager, never()).update(acct2); + assertThat(device1.getLastSeen()).isEqualTo(yesterday); + assertThat(device2.getLastSeen()).isEqualTo(today); + } - assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday); - assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday); - } + @Test + void testUpdateLastSeenStartOfDay() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(today)); - @Test - public void testUpdateLastSeenEndOfDay() { - when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1)); + final Device device1 = acct1.getDevices().stream().findFirst().get(); + final Device device2 = acct2.getDevices().stream().findFirst().get(); - baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get()); - baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get()); + baseAccountAuthenticator.updateLastSeen(acct1, device1); + baseAccountAuthenticator.updateLastSeen(acct2, device2); - verify(accountsManager).update(acct1); - verify(accountsManager).update(acct2); + verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any()); + verify(accountsManager, never()).updateDevice(eq(acct2), anyLong(), any()); - assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today); - assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today); - } + assertThat(device1.getLastSeen()).isEqualTo(yesterday); + assertThat(device2.getLastSeen()).isEqualTo(yesterday); + } - @Test - public void testNeverWriteYesterday() { - when(clock.instant()).thenReturn(Instant.ofEpochMilli(today)); + @Test + void testUpdateLastSeenEndOfDay() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1)); - baseAccountAuthenticator.updateLastSeen(oldAccount, oldAccount.getDevices().stream().findFirst().get()); + final Device device1 = acct1.getDevices().stream().findFirst().get(); + final Device device2 = acct2.getDevices().stream().findFirst().get(); - verify(accountsManager).update(oldAccount); + baseAccountAuthenticator.updateLastSeen(acct1, device1); + baseAccountAuthenticator.updateLastSeen(acct2, device2); - assertThat(oldAccount.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today); - } + verify(accountsManager).updateDevice(eq(acct1), anyLong(), any()); + verify(accountsManager).updateDevice(eq(acct2), anyLong(), any()); + + assertThat(device1.getLastSeen()).isEqualTo(today); + assertThat(device2.getLastSeen()).isEqualTo(today); + } + + @Test + void testNeverWriteYesterday() { + when(clock.instant()).thenReturn(Instant.ofEpochMilli(today)); + + final Device device = oldAccount.getDevices().stream().findFirst().get(); + + baseAccountAuthenticator.updateLastSeen(oldAccount, device); + + verify(accountsManager).updateDevice(eq(oldAccount), anyLong(), any()); + + assertThat(device.getLastSeen()).isEqualTo(today); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index 30be19831..ee0064c6d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.*; +import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; @@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.storage.UsernamesManager; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.Hex; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -171,6 +173,8 @@ class AccountControllerTest { new SecureRandom().nextBytes(registration_lock_key); AuthenticationCredentials registrationLockCredentials = new AuthenticationCredentials(Hex.toStringCondensed(registration_lock_key)); + AccountsHelper.setupMockUpdate(accountsManager); + when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVoiceDestinationDailyLimiter()).thenReturn(rateLimiter); @@ -1352,7 +1356,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("c00lz0rz")); - verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class)); } @@ -1368,7 +1372,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000")); - verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class)); } @@ -1385,7 +1389,7 @@ class AccountControllerTest { verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first")); verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second")); - verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class)); } @@ -1402,7 +1406,7 @@ class AccountControllerTest { verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first")); verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null); - verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class)); } @@ -1419,7 +1423,7 @@ class AccountControllerTest { verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("third")); verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("fourth")); - verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT)); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class)); } @@ -1544,7 +1548,7 @@ class AccountControllerTest { .put(Entity.json(new AccountAttributes(false, 2222, null, null, null, true, null))); assertThat(response.getStatus()).isEqualTo(204); - verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.UNDISCOVERABLE_ACCOUNT); + verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.UNDISCOVERABLE_ACCOUNT)); } @Test @@ -1557,7 +1561,7 @@ class AccountControllerTest { .put(Entity.json(new AccountAttributes(false, 2222, null, null, null, false, null))); assertThat(response.getStatus()).isEqualTo(204); - verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.VALID_ACCOUNT); + verify(directoryQueue, times(1)).refreshRegisteredUser(eqUuid(AuthHelper.VALID_ACCOUNT)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 016c9ecff..409bdc176 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -47,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.VerificationCode; @@ -124,6 +126,8 @@ public class DeviceControllerTest { when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty()); when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account)); when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount)); + + AccountsHelper.setupMockUpdate(accountsManager); } @Test @@ -360,7 +364,7 @@ public class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId); - verify(accountsManager, times(1)).update(AuthHelper.VALID_ACCOUNT); + verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any()); verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 3b96d213c..09c540b42 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.Mockito.argThat; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; @@ -15,6 +16,7 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; @@ -61,6 +63,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -114,13 +117,15 @@ class KeysControllerTest { final Device sampleDevice3 = mock(Device.class); final Device sampleDevice4 = mock(Device.class); - Set allDevices = new HashSet() {{ + Set allDevices = new HashSet<>() {{ add(sampleDevice); add(sampleDevice2); add(sampleDevice3); add(sampleDevice4); }}; + AccountsHelper.setupMockUpdate(accounts); + when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID); when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); @@ -142,7 +147,7 @@ class KeysControllerTest { when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2)); when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3)); when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4)); - when(existsAccount.getDevice(22L)).thenReturn(Optional.empty()); + when(existsAccount.getDevice(22L)).thenReturn(Optional.empty()); when(existsAccount.getDevices()).thenReturn(allDevices); when(existsAccount.isEnabled()).thenReturn(true); when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey"); @@ -256,7 +261,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT)); + verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any()); } @Test @@ -271,7 +276,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT)); + verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any()); } @@ -578,7 +583,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(keysDynamoDb).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); + verify(keysDynamoDb).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -587,7 +592,7 @@ class KeysControllerTest { verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar")); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); - verify(accounts).update(AuthHelper.VALID_ACCOUNT); + verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); } @Test @@ -612,7 +617,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(keysDynamoDb).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); + verify(keysDynamoDb).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -621,7 +626,7 @@ class KeysControllerTest { verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq("barbar")); verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); - verify(accounts).update(AuthHelper.DISABLED_ACCOUNT); + verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java index 17fa52d79..26703c115 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java @@ -20,7 +20,8 @@ import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; -import io.dropwizard.testing.junit.ResourceTestRule; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; import java.util.Collections; import java.util.Optional; import java.util.Set; @@ -29,9 +30,10 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; -import org.junit.Before; -import org.junit.ClassRule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatcher; import org.signal.zkgroup.InvalidInputException; @@ -57,13 +59,15 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.UsernamesManager; import org.whispersystems.textsecuregcm.storage.VersionedProfile; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; -public class ProfileControllerTest { +@ExtendWith(DropwizardExtensionsSupport.class) +class ProfileControllerTest { private static AccountsManager accountsManager = mock(AccountsManager.class ); private static ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -82,30 +86,30 @@ public class ProfileControllerTest { private Account profileAccount; + private static final ResourceExtension resources = ResourceExtension.builder() + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new ProfileController(rateLimiters, + accountsManager, + profilesManager, + usernamesManager, + dynamicConfigurationManager, + s3client, + postPolicyGenerator, + policySigner, + "profilesBucket", + zkProfileOperations, + true)) + .build(); - @ClassRule - public static final ResourceTestRule resources = ResourceTestRule.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new ProfileController(rateLimiters, - accountsManager, - profilesManager, - usernamesManager, - dynamicConfigurationManager, - s3client, - postPolicyGenerator, - policySigner, - "profilesBucket", - zkProfileOperations, - true)) - .build(); - - @Before - public void setup() throws Exception { + @BeforeEach + void setup() throws Exception { reset(s3client); + AccountsHelper.setupMockUpdate(accountsManager); + dynamicPaymentsConfiguration = mock(DynamicPaymentsConfiguration.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); @@ -161,8 +165,13 @@ public class ProfileControllerTest { clearInvocations(profilesManager); } + @AfterEach + void teardown() { + reset(accountsManager); + } + @Test - public void testProfileGetByUuid() throws RateLimitExceededException { + void testProfileGetByUuid() throws RateLimitExceededException { Profile profile= resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .request() @@ -180,7 +189,7 @@ public class ProfileControllerTest { } @Test - public void testProfileGetByNumber() throws RateLimitExceededException { + void testProfileGetByNumber() throws RateLimitExceededException { Profile profile= resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO) .request() @@ -201,7 +210,7 @@ public class ProfileControllerTest { } @Test - public void testProfileGetByUsername() throws RateLimitExceededException { + void testProfileGetByUsername() throws RateLimitExceededException { Profile profile= resources.getJerseyTest() .target("/v1/profile/username/n00bkiller") .request() @@ -220,7 +229,7 @@ public class ProfileControllerTest { } @Test - public void testProfileGetUnauthorized() { + void testProfileGetUnauthorized() { Response response = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO) .request() @@ -230,7 +239,7 @@ public class ProfileControllerTest { } @Test - public void testProfileGetByUsernameUnauthorized() { + void testProfileGetByUsernameUnauthorized() { Response response = resources.getJerseyTest() .target("/v1/profile/username/n00bkiller") .request() @@ -241,7 +250,7 @@ public class ProfileControllerTest { @Test - public void testProfileGetByUsernameNotFound() throws RateLimitExceededException { + void testProfileGetByUsernameNotFound() throws RateLimitExceededException { Response response = resources.getJerseyTest() .target("/v1/profile/username/n00bkillerzzzzz") .request() @@ -256,7 +265,7 @@ public class ProfileControllerTest { @Test - public void testProfileGetDisabled() { + void testProfileGetDisabled() { Response response = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_NUMBER_TWO) .request() @@ -267,7 +276,7 @@ public class ProfileControllerTest { } @Test - public void testProfileCapabilities() { + void testProfileCapabilities() { Profile profile= resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_NUMBER) .request() @@ -293,7 +302,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfileNameDeprecated() { + void testSetProfileNameDeprecated() { Response response = resources.getJerseyTest() .target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012") .request() @@ -302,11 +311,11 @@ public class ProfileControllerTest { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager, times(1)).update(any(Account.class)); + verify(accountsManager, times(1)).update(any(Account.class), any()); } @Test - public void testSetProfileNameExtendedDeprecated() { + void testSetProfileNameExtendedDeprecated() { Response response = resources.getJerseyTest() .target("/v1/profile/name/123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678") .request() @@ -315,11 +324,11 @@ public class ProfileControllerTest { assertThat(response.getStatus()).isEqualTo(204); - verify(accountsManager, times(1)).update(any(Account.class)); + verify(accountsManager, times(1)).update(any(Account.class), any()); } @Test - public void testSetProfileNameWrongSizeDeprecated() { + void testSetProfileNameWrongSizeDeprecated() { Response response = resources.getJerseyTest() .target("/v1/profile/name/1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890") .request() @@ -333,7 +342,7 @@ public class ProfileControllerTest { ///// @Test - public void testSetProfileWantAvatarUpload() throws InvalidInputException { + void testSetProfileWantAvatarUpload() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID); ProfileAvatarUploadAttributes uploadAttributes = resources.getJerseyTest() @@ -358,7 +367,7 @@ public class ProfileControllerTest { assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); } @Test - public void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException { + void testSetProfileWantAvatarUploadWithBadProfileSize() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID); Response response = resources.getJerseyTest() @@ -372,7 +381,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfileWithoutAvatarUpload() throws InvalidInputException { + void testSetProfileWithoutAvatarUpload() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID); clearInvocations(AuthHelper.VALID_ACCOUNT_TWO); @@ -406,7 +415,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException { + void testSetProfileWithAvatarUploadAndPreviousAvatar() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO); ProfileAvatarUploadAttributes uploadAttributes= resources.getJerseyTest() @@ -430,7 +439,7 @@ public class ProfileControllerTest { assertThat(profileArgumentCaptor.getValue().getAbout()).isNull(); } @Test - public void testSetProfileExtendedName() throws InvalidInputException { + void testSetProfileExtendedName() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID_TWO); final String name = RandomStringUtils.randomAlphabetic(380); @@ -456,7 +465,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfileEmojiAndBioText() throws InvalidInputException { + void testSetProfileEmojiAndBioText() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID); clearInvocations(AuthHelper.VALID_ACCOUNT_TWO); @@ -495,7 +504,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfilePaymentAddress() throws InvalidInputException { + void testSetProfilePaymentAddress() throws InvalidInputException { when(dynamicPaymentsConfiguration.getAllowedCountryCodes()) .thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO))); @@ -536,7 +545,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException { + void testSetProfilePaymentAddressCountryNotAllowed() throws InvalidInputException { ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(AuthHelper.VALID_UUID); clearInvocations(AuthHelper.VALID_ACCOUNT_TWO); @@ -557,7 +566,7 @@ public class ProfileControllerTest { } @Test - public void testGetProfileByVersion() throws RateLimitExceededException { + void testGetProfileByVersion() throws RateLimitExceededException { Profile profile = resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO + "/validversion") .request() @@ -582,7 +591,7 @@ public class ProfileControllerTest { } @Test - public void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException { + void testSetProfileUpdatesAccountCurrentVersion() throws InvalidInputException { when(dynamicPaymentsConfiguration.getAllowedCountryCodes()) .thenReturn(Set.of(Util.getCountryCode(AuthHelper.VALID_NUMBER_TWO))); @@ -606,7 +615,7 @@ public class ProfileControllerTest { } @Test - public void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() { + void testGetProfileReturnsNoPaymentAddressIfCurrentVersionMismatch() { when(profilesManager.get(AuthHelper.VALID_UUID_TWO, "validversion")).thenReturn( Optional.of(new VersionedProfile(null, null, null, null, null, "paymentaddress", null))); Profile profile = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/push/GCMSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/push/GCMSenderTest.java index 710a2078b..3c8940e46 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/push/GCMSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/push/GCMSenderTest.java @@ -23,6 +23,7 @@ import org.whispersystems.textsecuregcm.push.GcmMessage; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService; import org.whispersystems.textsecuregcm.util.Util; @@ -40,6 +41,8 @@ public class GCMSenderTest { when(successResult.hasCanonicalRegistrationId()).thenReturn(false); when(successResult.isSuccess()).thenReturn(true); + AccountsHelper.setupMockUpdate(accountsManager); + GcmMessage message = new GcmMessage("foo", "+12223334444", 1, GcmMessage.Type.NOTIFICATION, Optional.empty()); GCMSender gcmSender = new GCMSender(executorService, accountsManager, sender); @@ -65,6 +68,8 @@ public class GCMSenderTest { Account destinationAccount = mock(Account.class); Device destinationDevice = mock(Device.class ); + AccountsHelper.setupMockUpdate(accountsManager); + when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice)); when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount)); when(destinationDevice.getGcmId()).thenReturn(gcmId); @@ -85,7 +90,7 @@ public class GCMSenderTest { verify(sender, times(1)).send(any(Message.class)); verify(accountsManager, times(1)).get(eq(destinationNumber)); - verify(accountsManager, times(1)).update(eq(destinationAccount)); + verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any()); verify(destinationDevice, times(1)).setUninstalledFeedbackTimestamp(eq(Util.todayInMillis())); } @@ -107,6 +112,8 @@ public class GCMSenderTest { when(accountsManager.get(destinationNumber)).thenReturn(Optional.of(destinationAccount)); when(destinationDevice.getGcmId()).thenReturn(gcmId); + AccountsHelper.setupMockUpdate(accountsManager); + when(canonicalResult.isInvalidRegistrationId()).thenReturn(false); when(canonicalResult.isUnregistered()).thenReturn(false); when(canonicalResult.hasCanonicalRegistrationId()).thenReturn(true); @@ -124,7 +131,7 @@ public class GCMSenderTest { verify(sender, times(1)).send(any(Message.class)); verify(accountsManager, times(1)).get(eq(destinationNumber)); - verify(accountsManager, times(1)).update(eq(destinationAccount)); + verify(accountsManager, times(1)).updateDevice(eq(destinationAccount), eq(1L), any()); verify(destinationDevice, times(1)).setGcmId(eq(canonicalId)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountTest.java index e3601f4a2..d53ad861d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountTest.java @@ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.tests.storage; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -17,26 +19,26 @@ import java.util.HashSet; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; -public class AccountTest { +class AccountTest { - private final Device oldMasterDevice = mock(Device.class); - private final Device recentMasterDevice = mock(Device.class); - private final Device agingSecondaryDevice = mock(Device.class); + private final Device oldMasterDevice = mock(Device.class); + private final Device recentMasterDevice = mock(Device.class); + private final Device agingSecondaryDevice = mock(Device.class); private final Device recentSecondaryDevice = mock(Device.class); - private final Device oldSecondaryDevice = mock(Device.class); + private final Device oldSecondaryDevice = mock(Device.class); - private final Device gv2CapableDevice = mock(Device.class); - private final Device gv2IncapableDevice = mock(Device.class); + private final Device gv2CapableDevice = mock(Device.class); + private final Device gv2IncapableDevice = mock(Device.class); private final Device gv2IncapableExpiredDevice = mock(Device.class); private final Device gv1MigrationCapableDevice = mock(Device.class); - private final Device gv1MigrationIncapableDevice = mock(Device.class); + private final Device gv1MigrationIncapableDevice = mock(Device.class); private final Device gv1MigrationIncapableExpiredDevice = mock(Device.class); private final Device senderKeyCapableDevice = mock(Device.class); @@ -47,8 +49,8 @@ public class AccountTest { private final Device announcementGroupIncapableDevice = mock(Device.class); private final Device announcementGroupIncapableExpiredDevice = mock(Device.class); - @Before - public void setup() { + @BeforeEach + void setup() { when(oldMasterDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366)); when(oldMasterDevice.isEnabled()).thenReturn(true); when(oldMasterDevice.getId()).thenReturn(Device.MASTER_ID); @@ -119,9 +121,9 @@ public class AccountTest { } @Test - public void testIsEnabled() { - final Device enabledMasterDevice = mock(Device.class); - final Device enabledLinkedDevice = mock(Device.class); + void testIsEnabled() { + final Device enabledMasterDevice = mock(Device.class); + final Device enabledLinkedDevice = mock(Device.class); final Device disabledMasterDevice = mock(Device.class); final Device disabledLinkedDevice = mock(Device.class); @@ -144,7 +146,7 @@ public class AccountTest { } @Test - public void testCapabilities() { + void testCapabilities() { Account uuidCapable = new Account("+14152222222", UUID.randomUUID(), new HashSet() {{ add(gv2CapableDevice); }}, "1234".getBytes()); @@ -165,13 +167,13 @@ public class AccountTest { } @Test - public void testIsTransferSupported() { - final Device transferCapableMasterDevice = mock(Device.class); - final Device nonTransferCapableMasterDevice = mock(Device.class); - final Device transferCapableLinkedDevice = mock(Device.class); + void testIsTransferSupported() { + final Device transferCapableMasterDevice = mock(Device.class); + final Device nonTransferCapableMasterDevice = mock(Device.class); + final Device transferCapableLinkedDevice = mock(Device.class); - final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class); - final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class); + final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class); + final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class); when(transferCapableMasterDevice.getId()).thenReturn(1L); when(transferCapableMasterDevice.isMaster()).thenReturn(true); @@ -213,10 +215,12 @@ public class AccountTest { } @Test - public void testDiscoverableByPhoneNumber() { - final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice), "1234".getBytes()); + void testDiscoverableByPhoneNumber() { + final Account account = new Account("+14152222222", UUID.randomUUID(), Collections.singleton(recentMasterDevice), + "1234".getBytes()); - assertTrue("Freshly-loaded legacy accounts should be discoverable by phone number.", account.isDiscoverableByPhoneNumber()); + assertTrue(account.isDiscoverableByPhoneNumber(), + "Freshly-loaded legacy accounts should be discoverable by phone number."); account.setDiscoverableByPhoneNumber(false); assertFalse(account.isDiscoverableByPhoneNumber()); @@ -226,21 +230,29 @@ public class AccountTest { } @Test - public void isGroupsV2Supported() { - assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); - assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); - assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); + void isGroupsV2Supported() { + assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice), + "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); + assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableExpiredDevice), + "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); + assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv2CapableDevice, gv2IncapableDevice), + "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); } @Test - public void isGv1MigrationSupported() { - assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); - assertFalse(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); - assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); + void isGv1MigrationSupported() { + assertTrue(new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice), + "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); + assertFalse( + new Account("+18005551234", UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice), + "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); + assertTrue(new Account("+18005551234", UUID.randomUUID(), + Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)) + .isGv1MigrationSupported()); } @Test - public void isSenderKeySupported() { + void isSenderKeySupported() { assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue(); assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(senderKeyCapableDevice, senderKeyIncapableDevice), @@ -251,7 +263,7 @@ public class AccountTest { } @Test - public void isAnnouncementGroupSupported() { + void isAnnouncementGroupSupported() { assertThat(new Account("+18005551234", UUID.randomUUID(), Set.of(announcementGroupCapableDevice), "1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue(); @@ -262,4 +274,16 @@ public class AccountTest { Set.of(announcementGroupCapableDevice, announcementGroupIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue(); } + + @Test + void stale() { + final Account account = new Account("+14151234567", UUID.randomUUID(), Collections.emptySet(), new byte[0]); + + assertDoesNotThrow(account::getNumber); + + account.markStale(); + + assertThrows(AssertionError.class, account::getNumber); + assertDoesNotThrow(account::getUuid); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java index 4b8483579..5aa0d41ed 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsManagerTest.java @@ -6,11 +6,14 @@ package org.whispersystems.textsecuregcm.tests.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -22,13 +25,17 @@ import static org.mockito.Mockito.when; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import java.io.IOException; import java.util.HashSet; import java.util.Optional; import java.util.UUID; +import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.SignedPreKey; @@ -41,6 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsDynamoDb; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException; import org.whispersystems.textsecuregcm.storage.DeletedAccounts; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; @@ -48,14 +56,22 @@ import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.UsernamesManager; +import org.whispersystems.textsecuregcm.tests.util.JsonHelpers; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; -import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; class AccountsManagerTest { private DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); private ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); + private static final Answer ACCOUNT_UPDATE_ANSWER = (answer) -> { + // it is implicit in the update() contract is that a successful call will + // result in an incremented version + final Account updatedAccount = answer.getArgument(0, Account.class); + updatedAccount.setVersion(updatedAccount.getVersion() + 1); + return null; + }; + @BeforeEach void setup() { @@ -326,7 +342,7 @@ class AccountsManagerTest { @ParameterizedTest @ValueSource(booleans = {true, false}) - void testUpdate_dynamoDbMigration(boolean dynamoEnabled) { + void testUpdate_dynamoDbMigration(boolean dynamoEnabled) throws IOException { RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); Accounts accounts = mock(Accounts.class); @@ -335,35 +351,59 @@ class AccountsManagerTest { DirectoryQueue directoryQueue = mock(DirectoryQueue.class); KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); MessagesManager messagesManager = mock(MessagesManager.class); - UsernamesManager usernamesManager = mock(UsernamesManager.class); - ProfilesManager profilesManager = mock(ProfilesManager.class); - SecureBackupClient secureBackupClient = mock(SecureBackupClient.class); - SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); - UUID uuid = UUID.randomUUID(); - Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]); + UsernamesManager usernamesManager = mock(UsernamesManager.class); + ProfilesManager profilesManager = mock(ProfilesManager.class); + SecureBackupClient secureBackupClient = mock(SecureBackupClient.class); + SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); + UUID uuid = UUID.randomUUID(); + Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]); enableDynamo(dynamoEnabled); when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + // database fetches should always return new instances + when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16]))); + when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16]))); + doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any(Account.class)); AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager); - assertEquals(0, account.getDynamoDbMigrationVersion()); + Account updatedAccount = accountsManager.update(account, a -> a.setProfileName("name")); - accountsManager.update(account); + assertThrows(AssertionError.class, account::getProfileName, "Account passed to update() should be stale"); - assertEquals(1, account.getDynamoDbMigrationVersion()); + assertNotSame(updatedAccount, account); verify(accounts, times(1)).update(account); verifyNoMoreInteractions(accounts); - verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).update(account); + if (dynamoEnabled) { + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Account.class); + verify(accountsDynamoDb, times(1)).update(argumentCaptor.capture()); + assertEquals(uuid, argumentCaptor.getValue().getUuid()); + } else { + verify(accountsDynamoDb, never()).update(any()); + } + verify(accountsDynamoDb, dynamoEnabled ? times(1) : never()).get(uuid); verifyNoMoreInteractions(accountsDynamoDb); + + ArgumentCaptor redisSetArgumentCapture = ArgumentCaptor.forClass(String.class); + + verify(commands, times(4)).set(anyString(), redisSetArgumentCapture.capture()); + + Account firstAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(1), Account.class); + Account secondAccountCached = JsonHelpers.fromJson(redisSetArgumentCapture.getAllValues().get(3), Account.class); + + // uuid is @JsonIgnore, so we need to set it for compareAccounts to work + firstAccountCached.setUuid(uuid); + secondAccountCached.setUuid(uuid); + + assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(firstAccountCached), Optional.of(secondAccountCached))); } @Test - void testUpdate_dynamoConditionFailed() { + void testUpdate_dynamoMissing() { RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); Accounts accounts = mock(Accounts.class); @@ -382,25 +422,158 @@ class AccountsManagerTest { enableDynamo(true); when(commands.get(eq("Account3::" + uuid))).thenReturn(null); - doThrow(ConditionalCheckFailedException.class).when(accountsDynamoDb).update(any(Account.class)); + when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty()); + doAnswer(ACCOUNT_UPDATE_ANSWER).when(accounts).update(any()); + doAnswer(ACCOUNT_UPDATE_ANSWER).when(accountsDynamoDb).update(any()); AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager); - assertEquals(0, account.getDynamoDbMigrationVersion()); - - accountsManager.update(account); - - assertEquals(1, account.getDynamoDbMigrationVersion()); + Account updatedAccount = accountsManager.update(account, a -> {}); verify(accounts, times(1)).update(account); verifyNoMoreInteractions(accounts); - verify(accountsDynamoDb, times(1)).update(account); - verify(accountsDynamoDb, times(1)).create(account); + verify(accountsDynamoDb, never()).update(account); + verify(accountsDynamoDb, times(1)).get(uuid); + verifyNoMoreInteractions(accountsDynamoDb); + + assertEquals(1, updatedAccount.getVersion()); + } + + @Test + void testUpdate_optimisticLockingFailure() { + RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); + Accounts accounts = mock(Accounts.class); + AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class); + DeletedAccounts deletedAccounts = mock(DeletedAccounts.class); + DirectoryQueue directoryQueue = mock(DirectoryQueue.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); + MessagesManager messagesManager = mock(MessagesManager.class); + UsernamesManager usernamesManager = mock(UsernamesManager.class); + ProfilesManager profilesManager = mock(ProfilesManager.class); + SecureBackupClient secureBackupClient = mock(SecureBackupClient.class); + SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); + UUID uuid = UUID.randomUUID(); + Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]); + + enableDynamo(true); + + when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + + when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16]))); + doThrow(ContestedOptimisticLockException.class) + .doAnswer(ACCOUNT_UPDATE_ANSWER) + .when(accounts).update(any()); + + when(accountsDynamoDb.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16]))); + doThrow(ContestedOptimisticLockException.class) + .doAnswer(ACCOUNT_UPDATE_ANSWER) + .when(accountsDynamoDb).update(any()); + + AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager); + + account = accountsManager.update(account, a -> a.setProfileName("name")); + + assertEquals(1, account.getVersion()); + assertEquals("name", account.getProfileName()); + + verify(accounts, times(1)).get(uuid); + verify(accounts, times(2)).update(any()); + verifyNoMoreInteractions(accounts); + + // dynamo has an extra get() because the account is fetched before every update + verify(accountsDynamoDb, times(2)).get(uuid); + verify(accountsDynamoDb, times(2)).update(any()); verifyNoMoreInteractions(accountsDynamoDb); } + @Test + void testUpdate_dynamoOptimisticLockingFailureDuringCreate() { + RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); + Accounts accounts = mock(Accounts.class); + AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class); + DeletedAccounts deletedAccounts = mock(DeletedAccounts.class); + DirectoryQueue directoryQueue = mock(DirectoryQueue.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); + MessagesManager messagesManager = mock(MessagesManager.class); + UsernamesManager usernamesManager = mock(UsernamesManager.class); + ProfilesManager profilesManager = mock(ProfilesManager.class); + SecureBackupClient secureBackupClient = mock(SecureBackupClient.class); + SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); + UUID uuid = UUID.randomUUID(); + Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]); + + enableDynamo(true); + + when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + when(accountsDynamoDb.get(uuid)).thenReturn(Optional.empty()) + .thenReturn(Optional.of(account)); + when(accountsDynamoDb.create(any())).thenThrow(ContestedOptimisticLockException.class); + + AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager); + + accountsManager.update(account, a -> {}); + + verify(accounts, times(1)).update(account); + verifyNoMoreInteractions(accounts); + + verify(accountsDynamoDb, times(1)).get(uuid); + verifyNoMoreInteractions(accountsDynamoDb); + } + + @Test + void testUpdateDevice() throws Exception { + RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); + FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands); + Accounts accounts = mock(Accounts.class); + AccountsDynamoDb accountsDynamoDb = mock(AccountsDynamoDb.class); + DeletedAccounts deletedAccounts = mock(DeletedAccounts.class); + DirectoryQueue directoryQueue = mock(DirectoryQueue.class); + KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class); + MessagesManager messagesManager = mock(MessagesManager.class); + UsernamesManager usernamesManager = mock(UsernamesManager.class); + ProfilesManager profilesManager = mock(ProfilesManager.class); + SecureBackupClient secureBackupClient = mock(SecureBackupClient.class); + SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); + + AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccounts, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager); + + assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.empty(), Optional.empty())); + + final UUID uuid = UUID.randomUUID(); + Account account = new Account("+14152222222", uuid, new HashSet<>(), new byte[16]); + + when(accounts.get(uuid)).thenReturn(Optional.of(new Account("+14152222222", uuid, new HashSet<>(), new byte[16]))); + + assertTrue(account.getDevices().isEmpty()); + + Device enabledDevice = new Device(); + enabledDevice.setFetchesMessages(true); + enabledDevice.setSignedPreKey(new SignedPreKey(1L, "key", "signature")); + enabledDevice.setLastSeen(System.currentTimeMillis()); + final long deviceId = account.getNextDeviceId(); + enabledDevice.setId(deviceId); + account.addDevice(enabledDevice); + + @SuppressWarnings("unchecked") Consumer deviceUpdater = mock(Consumer.class); + @SuppressWarnings("unchecked") Consumer unknownDeviceUpdater = mock(Consumer.class); + + account = accountsManager.updateDevice(account, deviceId, deviceUpdater); + account = accountsManager.updateDevice(account, deviceId, d -> d.setName("deviceName")); + + assertEquals("deviceName", account.getDevice(deviceId).get().getName()); + + verify(deviceUpdater, times(1)).accept(any(Device.class)); + + accountsManager.updateDevice(account, account.getNextDeviceId(), unknownDeviceUpdater); + + verify(unknownDeviceUpdater, never()).accept(any(Device.class)); + } + + @Test void testCompareAccounts() throws Exception { RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); @@ -479,9 +652,11 @@ class AccountsManagerTest { assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2))); - a1.setDynamoDbMigrationVersion(1); + a1.setVersion(1); - assertEquals(Optional.empty(), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2))); + assertEquals(Optional.of("version"), accountsManager.compareAccounts(Optional.of(a1), Optional.of(a2))); + + a2.setVersion(1); a2.setProfileName("name"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java index 5957fb8eb..fafeff036 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/AccountsTest.java @@ -167,6 +167,12 @@ public class AccountsTest { accounts.update(account); + account.setProfileName("profileName"); + + accounts.update(account); + + assertThat(account.getVersion()).isEqualTo(2); + Optional retrieved = accounts.get("+14151112222"); assertThat(retrieved.isPresent()).isTrue(); @@ -359,6 +365,7 @@ public class AccountsTest { assertThat(result.getNumber()).isEqualTo(number); assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen()); assertThat(result.getUuid()).isEqualTo(uuid); + assertThat(result.getVersion()).isEqualTo(expecting.getVersion()); assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue(); for (Device expectingDevice : expecting.getDevices()) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java index 1a448d66e..35e0eddba 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/PushFeedbackProcessorTest.java @@ -5,16 +5,16 @@ package org.whispersystems.textsecuregcm.tests.storage; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor; -import org.whispersystems.textsecuregcm.util.Util; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; import java.util.Collections; import java.util.List; @@ -22,11 +22,20 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerRestartException; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; +import org.whispersystems.textsecuregcm.util.Util; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.*; - -public class PushFeedbackProcessorTest { +class PushFeedbackProcessorTest { private AccountsManager accountsManager = mock(AccountsManager.class); private DirectoryQueue directoryQueue = mock(DirectoryQueue.class); @@ -46,8 +55,10 @@ public class PushFeedbackProcessorTest { private Device stillActiveDevice = mock(Device.class); private Device undiscoverableDevice = mock(Device.class); - @Before - public void setup() { + @BeforeEach + void setup() { + AccountsHelper.setupMockUpdate(accountsManager); + when(uninstalledDevice.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2)); when(uninstalledDevice.getLastSeen()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(2)); when(uninstalledDeviceTwo.getUninstalledFeedbackTimestamp()).thenReturn(Util.todayInMillis() - TimeUnit.DAYS.toMillis(3)); @@ -85,7 +96,7 @@ public class PushFeedbackProcessorTest { @Test - public void testEmpty() throws AccountDatabaseCrawlerRestartException { + void testEmpty() throws AccountDatabaseCrawlerRestartException { PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue); processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), Collections.emptyList()); @@ -94,7 +105,7 @@ public class PushFeedbackProcessorTest { } @Test - public void testUpdate() throws AccountDatabaseCrawlerRestartException { + void testUpdate() throws AccountDatabaseCrawlerRestartException { PushFeedbackProcessor processor = new PushFeedbackProcessor(accountsManager, directoryQueue); processor.timeAndProcessCrawlChunk(Optional.of(UUID.randomUUID()), List.of(uninstalledAccount, mixedAccount, stillActiveAccount, freshAccount, cleanAccount, undiscoverableAccount)); @@ -102,7 +113,7 @@ public class PushFeedbackProcessorTest { verify(uninstalledDevice).setGcmId(isNull()); verify(uninstalledDevice).setFetchesMessages(eq(false)); - verify(accountsManager).update(eq(uninstalledAccount)); + verify(accountsManager).update(eq(uninstalledAccount), any()); verify(uninstalledDeviceTwo).setApnId(isNull()); verify(uninstalledDeviceTwo).setGcmId(isNull()); @@ -112,33 +123,35 @@ public class PushFeedbackProcessorTest { verify(installedDevice, never()).setGcmId(any()); verify(installedDevice, never()).setFetchesMessages(anyBoolean()); - verify(accountsManager).update(eq(mixedAccount)); + verify(accountsManager).update(eq(mixedAccount), any()); verify(recentUninstalledDevice, never()).setApnId(any()); verify(recentUninstalledDevice, never()).setGcmId(any()); verify(recentUninstalledDevice, never()).setFetchesMessages(anyBoolean()); - verify(accountsManager, never()).update(eq(freshAccount)); + verify(accountsManager, never()).update(eq(freshAccount), any()); verify(installedDeviceTwo, never()).setApnId(any()); verify(installedDeviceTwo, never()).setGcmId(any()); verify(installedDeviceTwo, never()).setFetchesMessages(anyBoolean()); - verify(accountsManager, never()).update(eq(cleanAccount)); + verify(accountsManager, never()).update(eq(cleanAccount), any()); verify(stillActiveDevice).setUninstalledFeedbackTimestamp(eq(0L)); verify(stillActiveDevice, never()).setApnId(any()); verify(stillActiveDevice, never()).setGcmId(any()); verify(stillActiveDevice, never()).setFetchesMessages(anyBoolean()); - verify(accountsManager).update(eq(stillActiveAccount)); + verify(accountsManager).update(eq(stillActiveAccount), any()); final ArgumentCaptor> refreshedAccountArgumentCaptor = ArgumentCaptor.forClass(List.class); verify(directoryQueue).refreshRegisteredUsers(refreshedAccountArgumentCaptor.capture()); - assertTrue(refreshedAccountArgumentCaptor.getValue().containsAll(List.of(undiscoverableAccount, uninstalledAccount))); + final List refreshedUuids = refreshedAccountArgumentCaptor.getValue().stream() + .map(Account::getUuid) + .collect(Collectors.toList()); + + assertTrue(refreshedUuids.containsAll(List.of(undiscoverableAccount.getUuid(), uninstalledAccount.getUuid()))); } - - } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java new file mode 100644 index 000000000..8f5287244 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -0,0 +1,148 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.tests.util; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.function.Consumer; +import org.mockito.MockingDetails; +import org.mockito.stubbing.Stubbing; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +public class AccountsHelper { + + public static void setupMockUpdate(final AccountsManager mockAccountsManager) { + when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { + final Account account = answer.getArgument(0, Account.class); + answer.getArgument(1, Consumer.class).accept(account); + + return copyAndMarkStale(account); + }); + + when(mockAccountsManager.updateDevice(any(), anyLong(), any())).thenAnswer(answer -> { + final Account account = answer.getArgument(0, Account.class); + final Long deviceId = answer.getArgument(1, Long.class); + account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class)); + + return copyAndMarkStale(account); + }); + } + + private static Account copyAndMarkStale(Account account) throws IOException { + MockingDetails mockingDetails = mockingDetails(account); + + final Account updatedAccount; + if (mockingDetails.isMock()) { + + updatedAccount = mock(Account.class); + + // it’s not possible to make `account` behave as if it were stale, because we use static mocks in AuthHelper + + for (Stubbing stubbing : mockingDetails.getStubbings()) { + switch (stubbing.getInvocation().getMethod().getName()) { + case "getUuid": { + when(updatedAccount.getUuid()).thenAnswer(stubbing); + break; + } + case "getNumber": { + when(updatedAccount.getNumber()).thenAnswer(stubbing); + break; + } + case "getDevices": { + when(updatedAccount.getDevices()) + .thenAnswer(stubbing); + break; + } + case "getDevice": { + when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0))) + .thenAnswer(stubbing); + break; + } + case "getMasterDevice": { + when(updatedAccount.getMasterDevice()).thenAnswer(stubbing); + break; + } + case "getAuthenticatedDevice": { + when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing); + break; + } + case "isEnabled": { + when(updatedAccount.isEnabled()).thenAnswer(stubbing); + break; + } + case "isDiscoverableByPhoneNumber": { + when(updatedAccount.isDiscoverableByPhoneNumber()).thenAnswer(stubbing); + break; + } + case "getNextDeviceId": { + when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing); + break; + } + case "isGroupsV2Supported": { + when(updatedAccount.isGroupsV2Supported()).thenAnswer(stubbing); + break; + } + case "isGv1MigrationSupported": { + when(updatedAccount.isGv1MigrationSupported()).thenAnswer(stubbing); + break; + } + case "isSenderKeySupported": { + when(updatedAccount.isSenderKeySupported()).thenAnswer(stubbing); + break; + } + case "isAnnouncementGroupSupported": { + when(updatedAccount.isAnnouncementGroupSupported()).thenAnswer(stubbing); + break; + } + case "getEnabledDeviceCount": { + when(updatedAccount.getEnabledDeviceCount()).thenAnswer(stubbing); + break; + } + case "getRelay": { + // TODO unused + when(updatedAccount.getRelay()).thenAnswer(stubbing); + break; + } + case "getRegistrationLock": { + when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing); + break; + } + case "getIdentityKey": { + when(updatedAccount.getIdentityKey()).thenAnswer(stubbing); + break; + } + default: { + throw new IllegalArgumentException( + "unsupported method: Account#" + stubbing.getInvocation().getMethod().getName()); + } + } + } + + } else { + final ObjectMapper mapper = SystemMapper.getMapper(); + updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class); + updatedAccount.setNumber(account.getNumber()); + account.markStale(); + } + + + return updatedAccount; + } + + public static Account eqUuid(Account value) { + return argThat(other -> other.getUuid().equals(value.getUuid())); + } + +}