Lock account when number owner lacks registration lock.

This commit is contained in:
erik-signal 2022-11-09 13:58:50 -05:00 committed by Erik Osheim
parent e6e6eb323d
commit 80a3a8a43c
9 changed files with 193 additions and 89 deletions

View File

@ -654,7 +654,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register( environment.jersey().register(
new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters, new AccountController(pendingAccountsManager, accountsManager, abusiveHostRules, rateLimiters,
registrationServiceClient, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), registrationServiceClient, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(),
recaptchaClient, pushNotificationManager, changeNumberManager, backupCredentialsGenerator)); recaptchaClient, pushNotificationManager, changeNumberManager, backupCredentialsGenerator,
clientPresenceManager, clock));
environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager)); environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager));

View File

@ -21,6 +21,20 @@ public class StoredRegistrationLock {
private final long lastSeen; private final long lastSeen;
/**
* @return milliseconds since the last time the account was seen.
*/
private long timeSinceLastSeen() {
return System.currentTimeMillis() - lastSeen;
}
/**
* @return true if the registration lock and salt are both set.
*/
private boolean hasLockAndSalt() {
return registrationLock.isPresent() && registrationLockSalt.isPresent();
}
public StoredRegistrationLock(Optional<String> registrationLock, Optional<String> registrationLockSalt, long lastSeen) { public StoredRegistrationLock(Optional<String> registrationLock, Optional<String> registrationLockSalt, long lastSeen) {
this.registrationLock = registrationLock; this.registrationLock = registrationLock;
this.registrationLockSalt = registrationLockSalt; this.registrationLockSalt = registrationLockSalt;
@ -28,24 +42,22 @@ public class StoredRegistrationLock {
} }
public boolean requiresClientRegistrationLock() { public boolean requiresClientRegistrationLock() {
return registrationLock.isPresent() && registrationLockSalt.isPresent() && System.currentTimeMillis() - lastSeen < TimeUnit.DAYS.toMillis(7); boolean hasTimeRemaining = getTimeRemaining() >= 0;
return hasLockAndSalt() && hasTimeRemaining;
} }
public boolean needsFailureCredentials() { public boolean needsFailureCredentials() {
return registrationLock.isPresent() && registrationLockSalt.isPresent(); return hasLockAndSalt();
} }
public long getTimeRemaining() { public long getTimeRemaining() {
return TimeUnit.DAYS.toMillis(7) - (System.currentTimeMillis() - lastSeen); return TimeUnit.DAYS.toMillis(7) - timeSinceLastSeen();
} }
public boolean verify(@Nullable String clientRegistrationLock) { public boolean verify(@Nullable String clientRegistrationLock) {
if (Util.isEmpty(clientRegistrationLock)) { if (hasLockAndSalt() && Util.nonEmpty(clientRegistrationLock)) {
return false; AuthenticationCredentials credentials = new AuthenticationCredentials(registrationLock.get(), registrationLockSalt.get());
} return credentials.verify(clientRegistrationLock);
if (registrationLock.isPresent() && registrationLockSalt.isPresent() && !Util.isEmpty(clientRegistrationLock)) {
return new AuthenticationCredentials(registrationLock.get(), registrationLockSalt.get()).verify(clientRegistrationLock);
} else { } else {
return false; return false;
} }

View File

@ -19,8 +19,10 @@ import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -82,6 +84,7 @@ import org.whispersystems.textsecuregcm.entities.UsernameResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotification; import org.whispersystems.textsecuregcm.push.PushNotification;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient; import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
@ -154,35 +157,64 @@ public class AccountController {
private final ExternalServiceCredentialGenerator backupServiceCredentialGenerator; private final ExternalServiceCredentialGenerator backupServiceCredentialGenerator;
private final ChangeNumberManager changeNumberManager; private final ChangeNumberManager changeNumberManager;
private final Clock clock;
private final ClientPresenceManager clientPresenceManager;
@VisibleForTesting @VisibleForTesting
static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15); static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
public AccountController(StoredVerificationCodeManager pendingAccounts, public AccountController(
AccountsManager accounts, StoredVerificationCodeManager pendingAccounts,
AbusiveHostRules abusiveHostRules, AccountsManager accounts,
RateLimiters rateLimiters, AbusiveHostRules abusiveHostRules,
RegistrationServiceClient registrationServiceClient, RateLimiters rateLimiters,
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager, RegistrationServiceClient registrationServiceClient,
TurnTokenGenerator turnTokenGenerator, DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
Map<String, Integer> testDevices, TurnTokenGenerator turnTokenGenerator,
RecaptchaClient recaptchaClient, Map<String, Integer> testDevices,
PushNotificationManager pushNotificationManager, RecaptchaClient recaptchaClient,
ChangeNumberManager changeNumberManager, PushNotificationManager pushNotificationManager,
ExternalServiceCredentialGenerator backupServiceCredentialGenerator) ChangeNumberManager changeNumberManager,
{ ExternalServiceCredentialGenerator backupServiceCredentialGenerator,
this.pendingAccounts = pendingAccounts; ClientPresenceManager clientPresenceManager,
this.accounts = accounts; Clock clock
this.abusiveHostRules = abusiveHostRules; ) {
this.rateLimiters = rateLimiters; this.pendingAccounts = pendingAccounts;
this.registrationServiceClient = registrationServiceClient; this.accounts = accounts;
this.dynamicConfigurationManager = dynamicConfigurationManager; this.abusiveHostRules = abusiveHostRules;
this.testDevices = testDevices; this.rateLimiters = rateLimiters;
this.turnTokenGenerator = turnTokenGenerator; this.registrationServiceClient = registrationServiceClient;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.testDevices = testDevices;
this.turnTokenGenerator = turnTokenGenerator;
this.recaptchaClient = recaptchaClient; this.recaptchaClient = recaptchaClient;
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.backupServiceCredentialGenerator = backupServiceCredentialGenerator; this.backupServiceCredentialGenerator = backupServiceCredentialGenerator;
this.changeNumberManager = changeNumberManager; this.changeNumberManager = changeNumberManager;
this.clientPresenceManager = clientPresenceManager;
this.clock = clock;
}
@VisibleForTesting
public AccountController(
StoredVerificationCodeManager pendingAccounts,
AccountsManager accounts,
AbusiveHostRules abusiveHostRules,
RateLimiters rateLimiters,
RegistrationServiceClient registrationServiceClient,
DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
TurnTokenGenerator turnTokenGenerator,
Map<String, Integer> testDevices,
RecaptchaClient recaptchaClient,
PushNotificationManager pushNotificationManager,
ChangeNumberManager changeNumberManager,
ExternalServiceCredentialGenerator backupServiceCredentialGenerator
) {
this(pendingAccounts, accounts, abusiveHostRules, rateLimiters,
registrationServiceClient, dynamicConfigurationManager, turnTokenGenerator, testDevices, recaptchaClient,
pushNotificationManager, changeNumberManager,
backupServiceCredentialGenerator, null, Clock.systemUTC());
} }
@Timed @Timed
@ -205,7 +237,7 @@ public class AccountController {
String pushChallenge = generatePushChallenge(); String pushChallenge = generatePushChallenge();
StoredVerificationCode storedVerificationCode = StoredVerificationCode storedVerificationCode =
new StoredVerificationCode(null, System.currentTimeMillis(), pushChallenge, null, null); new StoredVerificationCode(null, clock.millis(), pushChallenge, null, null);
pendingAccounts.store(number, storedVerificationCode); pendingAccounts.store(number, storedVerificationCode);
pushNotificationManager.sendRegistrationChallengeNotification(pushToken, tokenType, storedVerificationCode.pushCode()); pushNotificationManager.sendRegistrationChallengeNotification(pushToken, tokenType, storedVerificationCode.pushCode());
@ -310,7 +342,7 @@ public class AccountController {
messageTransport, clientType, acceptLanguage.orElse(null), REGISTRATION_RPC_TIMEOUT).join(); messageTransport, clientType, acceptLanguage.orElse(null), REGISTRATION_RPC_TIMEOUT).join();
final StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null, final StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null,
System.currentTimeMillis(), clock.millis(),
maybeStoredVerificationCode.map(StoredVerificationCode::pushCode).orElse(null), maybeStoredVerificationCode.map(StoredVerificationCode::pushCode).orElse(null),
null, null,
sessionId); sessionId);
@ -777,6 +809,12 @@ public class AccountController {
} }
if (!existingRegistrationLock.verify(clientRegistrationLock)) { if (!existingRegistrationLock.verify(clientRegistrationLock)) {
// At this point, the client verified ownership of the phone number but doesnt have the reglock PIN.
// Freezing the existing account credentials will definitively start the reglock timeout. Until the timeout, the current reglock can still be supplied,
// along with phone number verification, to restore access.
accounts.update(existingAccount, Account::lockAuthenticationCredentials);
List<Long> deviceIds = existingAccount.getDevices().stream().map(Device::getId).toList();
clientPresenceManager.disconnectAllPresences(existingAccount.getUuid(), deviceIds);
throw new WebApplicationException(Response.status(423) throw new WebApplicationException(Response.status(423)
.entity(new RegistrationLockFailure(existingRegistrationLock.getTimeRemaining(), .entity(new RegistrationLockFailure(existingRegistrationLock.getTimeRemaining(),
existingRegistrationLock.needsFailureCredentials() ? existingBackupCredentials : null)) existingRegistrationLock.needsFailureCredentials() ? existingBackupCredentials : null))

View File

@ -13,6 +13,8 @@ import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed; import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.LettuceFutures;
import io.lettuce.core.RedisFuture;
import io.lettuce.core.ScriptOutputType; import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
@ -21,6 +23,7 @@ import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
@ -178,16 +181,25 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
List.of(managerId, String.valueOf(PRESENCE_EXPIRATION_SECONDS))); List.of(managerId, String.valueOf(PRESENCE_EXPIRATION_SECONDS)));
} }
public void disconnectAllPresences(final UUID accountUuid, final List<Long> deviceIds) {
List<String> presenceKeys = new ArrayList<>();
deviceIds.forEach(deviceId -> {
String presenceKey = getPresenceKey(accountUuid, deviceId);
if (isLocallyPresent(accountUuid, deviceId)) {
displacePresence(presenceKey, false);
}
presenceKeys.add(presenceKey);
});
presenceCluster.useCluster(connection -> {
List<RedisFuture<Long>> futures = presenceKeys.stream().map(key -> connection.async().del(key)).toList();
LettuceFutures.awaitAll(connection.getTimeout(), futures.toArray(new RedisFuture[0]));
});
}
public void disconnectPresence(final UUID accountUuid, final long deviceId) { public void disconnectPresence(final UUID accountUuid, final long deviceId) {
final String presenceKey = getPresenceKey(accountUuid, deviceId); disconnectAllPresences(accountUuid, List.of(deviceId));
if (isLocallyPresent(accountUuid, deviceId)) {
displacePresence(presenceKey, false);
}
// If connected locally, we still need to clean up the presence key.
// If connected remotely, the other server will get a keyspace message and handle the disconnect
presenceCluster.useCluster(connection -> connection.sync().del(presenceKey));
} }
private void displacePresence(final String presenceKey, final boolean connectedElsewhere) { private void displacePresence(final String presenceKey, final boolean connectedElsewhere) {

View File

@ -85,6 +85,7 @@ public class Account {
@JsonIgnore @JsonIgnore
private boolean canonicallyDiscoverable; private boolean canonicallyDiscoverable;
public UUID getUuid() { public UUID getUuid() {
// this is the one method that may be called on a stale account // this is the one method that may be called on a stale account
return uuid; return uuid;
@ -304,16 +305,10 @@ public class Account {
public long getLastSeen() { public long getLastSeen() {
requireNotStale(); requireNotStale();
return devices.stream()
long lastSeen = 0; .map(Device::getLastSeen)
.max(Long::compare)
for (Device device : devices) { .orElse(0L);
if (device.getLastSeen() > lastSeen) {
lastSeen = device.getLastSeen();
}
}
return lastSeen;
} }
public Optional<String> getCurrentProfileVersion() { public Optional<String> getCurrentProfileVersion() {
@ -344,7 +339,6 @@ public class Account {
public void addBadge(Clock clock, AccountBadge badge) { public void addBadge(Clock clock, AccountBadge badge) {
requireNotStale(); requireNotStale();
boolean added = false; boolean added = false;
for (int i = 0; i < badges.size(); i++) { for (int i = 0; i < badges.size(); i++) {
AccountBadge badgeInList = badges.get(i); AccountBadge badgeInList = badges.get(i);
@ -478,6 +472,19 @@ public class Account {
this.version = version; this.version = version;
} }
/**
* Lock account by invalidating authentication tokens.
*
* We only want to do this in cases where there is a potential conflict between the
* phone number holder and the registration lock holder. In that case, locking the
* account will ensure that either the registration lock holder proves ownership
* of the phone number, or after 7 days the phone number holder can register a new
* account.
*/
public void lockAuthenticationCredentials() {
devices.forEach(Device::lockAuthenticationCredentials);
}
boolean isStale() { boolean isStale() {
return stale; return stale;
} }

View File

@ -149,6 +149,20 @@ public class Device {
this.salt = credentials.getSalt(); this.salt = credentials.getSalt();
} }
/**
* Lock device by invalidating authentication tokens.
*
* This should only be used from Account::lockAuthenticationCredentials.
*
* See that method for more information.
*/
public void lockAuthenticationCredentials() {
AuthenticationCredentials oldAuth = getAuthenticationCredentials();
String token = "!" + oldAuth.getHashedAuthenticationToken();
String salt = oldAuth.getSalt();
setAuthenticationCredentials(new AuthenticationCredentials(token, salt));
}
public AuthenticationCredentials getAuthenticationCredentials() { public AuthenticationCredentials getAuthenticationCredentials() {
return new AuthenticationCredentials(authToken, salt); return new AuthenticationCredentials(authToken, salt);
} }

View File

@ -101,6 +101,10 @@ public class Util {
return param == null || param.length() == 0; return param == null || param.length() == 0;
} }
public static boolean nonEmpty(String param) {
return !isEmpty(param);
}
public static byte[] truncate(byte[] element, int length) { public static byte[] truncate(byte[] element, int length) {
byte[] result = new byte[length]; byte[] result = new byte[length];
System.arraycopy(element, 0, result, 0, result.length); System.arraycopy(element, 0, result, 0, result.length);
@ -138,15 +142,11 @@ public class Util {
return parts; return parts;
} }
public static int toIntExact(long value) { public static final long DAY_IN_MILLIS = 86400000L;
if ((int) value != value) { public static final long WEEK_IN_MILLIS = DAY_IN_MILLIS * 7;
throw new ArithmeticException("integer overflow");
}
return (int) value;
}
public static int currentDaysSinceEpoch(@Nonnull Clock clock) { public static int currentDaysSinceEpoch(@Nonnull Clock clock) {
return toIntExact(clock.millis() / 1000 / 60/ 60 / 24); return Math.toIntExact(clock.millis() / DAY_IN_MILLIS);
} }
public static void sleep(long i) { public static void sleep(long i) {
@ -180,12 +180,12 @@ public class Util {
} }
public static long todayInMillis(Clock clock) { public static long todayInMillis(Clock clock) {
return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(clock.instant().toEpochMilli())); return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(clock.millis()));
} }
public static long todayInMillisGivenOffsetFromNow(Clock clock, Duration offset) { public static long todayInMillisGivenOffsetFromNow(Clock clock, Duration offset) {
final long currentTimeSeconds = offset.addTo(clock.instant()).getLong(ChronoField.INSTANT_SECONDS); final long ms = offset.toMillis() + clock.millis();
return TimeUnit.DAYS.toMillis(TimeUnit.SECONDS.toDays(currentTimeSeconds)); return TimeUnit.DAYS.toMillis(TimeUnit.MILLISECONDS.toDays(ms));
} }
public static Optional<String> findBestLocale(List<LanguageRange> priorityList, Collection<String> supportedLocales) { public static Optional<String> findBestLocale(List<LanguageRange> priorityList, Collection<String> supportedLocales) {

View File

@ -13,6 +13,7 @@ import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
@ -22,7 +23,6 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -34,14 +34,12 @@ import com.google.i18n.phonenumbers.Phonenumber;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
@ -95,6 +93,7 @@ import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMa
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper; import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberResponse; import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberResponse;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotification; import org.whispersystems.textsecuregcm.push.PushNotification;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient; import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
@ -114,6 +113,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Hex; import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestClock;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class AccountControllerTest { class AccountControllerTest {
@ -146,26 +146,28 @@ class AccountControllerTest {
private static final String TEST_NUMBER = "+14151111113"; private static final String TEST_NUMBER = "+14151111113";
private static StoredVerificationCodeManager pendingAccountsManager = mock(StoredVerificationCodeManager.class); private static StoredVerificationCodeManager pendingAccountsManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class); private static AccountsManager accountsManager = mock(AccountsManager.class);
private static AbusiveHostRules abusiveHostRules = mock(AbusiveHostRules.class); private static AbusiveHostRules abusiveHostRules = mock(AbusiveHostRules.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class); private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static RateLimiter pinLimiter = mock(RateLimiter.class); private static RateLimiter pinLimiter = mock(RateLimiter.class);
private static RateLimiter smsVoiceIpLimiter = mock(RateLimiter.class); private static RateLimiter smsVoiceIpLimiter = mock(RateLimiter.class);
private static RateLimiter smsVoicePrefixLimiter = mock(RateLimiter.class); private static RateLimiter smsVoicePrefixLimiter = mock(RateLimiter.class);
private static RateLimiter autoBlockLimiter = mock(RateLimiter.class); private static RateLimiter autoBlockLimiter = mock(RateLimiter.class);
private static RateLimiter usernameSetLimiter = mock(RateLimiter.class); private static RateLimiter usernameSetLimiter = mock(RateLimiter.class);
private static RateLimiter usernameReserveLimiter = mock(RateLimiter.class); private static RateLimiter usernameReserveLimiter = mock(RateLimiter.class);
private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class); private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class);
private static RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private static RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class); private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private static Account senderPinAccount = mock(Account.class); private static Account senderPinAccount = mock(Account.class);
private static Account senderRegLockAccount = mock(Account.class); private static Account senderRegLockAccount = mock(Account.class);
private static Account senderHasStorage = mock(Account.class); private static Account senderHasStorage = mock(Account.class);
private static Account senderTransfer = mock(Account.class); private static Account senderTransfer = mock(Account.class);
private static RecaptchaClient recaptchaClient = mock(RecaptchaClient.class); private static RecaptchaClient recaptchaClient = mock(RecaptchaClient.class);
private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class); private static ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
private static ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private static TestClock testClock = TestClock.now();
private static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); private static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -194,7 +196,9 @@ class AccountControllerTest {
recaptchaClient, recaptchaClient,
pushNotificationManager, pushNotificationManager,
changeNumberManager, changeNumberManager,
storageCredentialGenerator)) storageCredentialGenerator,
clientPresenceManager,
testClock))
.build(); .build();
@ -340,7 +344,8 @@ class AccountControllerTest {
senderTransfer, senderTransfer,
recaptchaClient, recaptchaClient,
pushNotificationManager, pushNotificationManager,
changeNumberManager); changeNumberManager,
clientPresenceManager);
clearInvocations(AuthHelper.DISABLED_DEVICE); clearInvocations(AuthHelper.DISABLED_DEVICE);
} }
@ -1063,6 +1068,8 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(423); assertThat(response.getStatus()).isEqualTo(423);
verify(senderRegLockAccount).lockAuthenticationCredentials();
verify(clientPresenceManager, times(1)).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any());
verify(pinLimiter).validate(eq(SENDER_REG_LOCK)); verify(pinLimiter).validate(eq(SENDER_REG_LOCK));
} }
@ -1085,6 +1092,8 @@ class AccountControllerTest {
assertThat(failure.getBackupCredentials().getPassword().startsWith(SENDER_REG_LOCK_UUID.toString())).isTrue(); assertThat(failure.getBackupCredentials().getPassword().startsWith(SENDER_REG_LOCK_UUID.toString())).isTrue();
assertThat(failure.getTimeRemaining()).isGreaterThan(0); assertThat(failure.getTimeRemaining()).isGreaterThan(0);
verify(senderRegLockAccount).lockAuthenticationCredentials();
verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any());
verifyNoInteractions(pinLimiter); verifyNoInteractions(pinLimiter);
} }
@ -1311,9 +1320,10 @@ class AccountControllerTest {
final StoredRegistrationLock existingRegistrationLock = mock(StoredRegistrationLock.class); final StoredRegistrationLock existingRegistrationLock = mock(StoredRegistrationLock.class);
when(existingRegistrationLock.requiresClientRegistrationLock()).thenReturn(true); when(existingRegistrationLock.requiresClientRegistrationLock()).thenReturn(true);
final UUID existingUuid = UUID.randomUUID();
final Account existingAccount = mock(Account.class); final Account existingAccount = mock(Account.class);
when(existingAccount.getNumber()).thenReturn(number); when(existingAccount.getNumber()).thenReturn(number);
when(existingAccount.getUuid()).thenReturn(UUID.randomUUID()); when(existingAccount.getUuid()).thenReturn(existingUuid);
when(existingAccount.getRegistrationLock()).thenReturn(existingRegistrationLock); when(existingAccount.getRegistrationLock()).thenReturn(existingRegistrationLock);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(existingAccount)); when(accountsManager.getByE164(number)).thenReturn(Optional.of(existingAccount));
@ -1327,6 +1337,9 @@ class AccountControllerTest {
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423); assertThat(response.getStatus()).isEqualTo(423);
verify(existingAccount).lockAuthenticationCredentials();
verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
} }
@ -1347,9 +1360,10 @@ class AccountControllerTest {
when(existingRegistrationLock.requiresClientRegistrationLock()).thenReturn(true); when(existingRegistrationLock.requiresClientRegistrationLock()).thenReturn(true);
when(existingRegistrationLock.verify(anyString())).thenReturn(false); when(existingRegistrationLock.verify(anyString())).thenReturn(false);
UUID existingUuid = UUID.randomUUID();
final Account existingAccount = mock(Account.class); final Account existingAccount = mock(Account.class);
when(existingAccount.getNumber()).thenReturn(number); when(existingAccount.getNumber()).thenReturn(number);
when(existingAccount.getUuid()).thenReturn(UUID.randomUUID()); when(existingAccount.getUuid()).thenReturn(existingUuid);
when(existingAccount.getRegistrationLock()).thenReturn(existingRegistrationLock); when(existingAccount.getRegistrationLock()).thenReturn(existingRegistrationLock);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(existingAccount)); when(accountsManager.getByE164(number)).thenReturn(Optional.of(existingAccount));
@ -1363,6 +1377,9 @@ class AccountControllerTest {
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423); assertThat(response.getStatus()).isEqualTo(423);
verify(existingAccount).lockAuthenticationCredentials();
verify(clientPresenceManager, atLeastOnce()).disconnectAllPresences(eq(existingUuid), any());
verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any()); verify(changeNumberManager, never()).changeNumber(any(), any(), any(), any(), any(), any());
} }
@ -1399,6 +1416,8 @@ class AccountControllerTest {
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getStatus()).isEqualTo(200);
verify(senderRegLockAccount, never()).lockAuthenticationCredentials();
verify(clientPresenceManager, never()).disconnectAllPresences(eq(SENDER_REG_LOCK_UUID), any());
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any()); verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), any(), any(), any(), any(), any());
} }

View File

@ -132,6 +132,7 @@ public class AccountsHelper {
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing); case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
case "getIdentityKey" -> when(updatedAccount.getIdentityKey()).thenAnswer(stubbing); case "getIdentityKey" -> when(updatedAccount.getIdentityKey()).thenAnswer(stubbing);
case "getBadges" -> when(updatedAccount.getBadges()).thenAnswer(stubbing); case "getBadges" -> when(updatedAccount.getBadges()).thenAnswer(stubbing);
case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing);
default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName()); default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
} }
} }