diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 8eec8a668..69434b0d2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -756,8 +756,8 @@ public class WhisperServerService extends Application maxDeviceConfiguration; private final Clock clock; @@ -94,6 +101,7 @@ public class DeviceController { MessagesManager messages, KeysManager keys, RateLimiters rateLimiters, + FaultTolerantRedisCluster usedTokenCluster, Map maxDeviceConfiguration, final Clock clock) { this.pendingDevices = pendingDevices; this.verificationTokenKey = new SecretKeySpec(linkDeviceSecret, VERIFICATION_TOKEN_ALGORITHM); @@ -101,6 +109,7 @@ public class DeviceController { this.messages = messages; this.keys = keys; this.rateLimiters = rateLimiters; + this.usedTokenCluster = usedTokenCluster; this.maxDeviceConfiguration = maxDeviceConfiguration; this.clock = clock; @@ -298,6 +307,13 @@ public class DeviceController { @VisibleForTesting Optional checkVerificationToken(final String verificationToken) { + final boolean tokenUsed = usedTokenCluster.withCluster(connection -> + connection.sync().get(getUsedTokenKey(verificationToken)) != null); + + if (tokenUsed) { + return Optional.empty(); + } + final String[] claimsAndSignature = verificationToken.split(":", 2); if (claimsAndSignature.length != 2) { @@ -371,8 +387,9 @@ public class DeviceController { rateLimiters.getVerifyDeviceLimiter().validate(phoneNumber); - final Account account = checkVerificationToken(verificationCode) - .flatMap(accounts::getByAccountIdentifier) + final Optional maybeAciFromToken = checkVerificationToken(verificationCode); + + final Account account = maybeAciFromToken.flatMap(accounts::getByAccountIdentifier) .or(() -> { final boolean verificationCodeValid = pendingDevices.getCodeForNumber(phoneNumber) .map(storedVerificationCode -> storedVerificationCode.isValid(verificationCode)) @@ -468,6 +485,15 @@ public class DeviceController { pendingDevices.remove(phoneNumber); + if (maybeAciFromToken.isPresent()) { + usedTokenCluster.useCluster(connection -> + connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))); + } + return new Pair<>(updatedAccount, device); } + + private static String getUsedTokenKey(final String token) { + return "usedToken::" + token; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 9872649b2..23bfe3788 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -4,40 +4,12 @@ */ package org.whispersystems.textsecuregcm.controllers; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.Mockito.clearInvocations; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - import com.google.common.collect.ImmutableSet; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; -import java.nio.charset.StandardCharsets; -import java.security.SecureRandom; -import java.time.Clock; -import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; -import javax.ws.rs.Path; -import javax.ws.rs.client.Entity; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -55,31 +27,40 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; -import org.whispersystems.textsecuregcm.entities.AccountAttributes; -import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; -import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; -import org.whispersystems.textsecuregcm.entities.DeviceResponse; -import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; -import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; -import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; -import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; +import org.whispersystems.textsecuregcm.entities.*; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.storage.*; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; -import org.whispersystems.textsecuregcm.storage.KeysManager; -import org.whispersystems.textsecuregcm.storage.MessagesManager; -import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.VerificationCode; +import javax.ws.rs.Path; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.time.Clock; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.*; + @ExtendWith(DropwizardExtensionsSupport.class) class DeviceControllerTest { @@ -92,10 +73,11 @@ class DeviceControllerTest { MessagesManager messages, KeysManager keys, RateLimiters rateLimiters, + FaultTolerantRedisCluster usedTokenCluster, Map deviceConfiguration, Clock clock) { - super(pendingDevices, linkDeviceSecret, accounts, messages, keys, rateLimiters, deviceConfiguration, clock); + super(pendingDevices, linkDeviceSecret, accounts, messages, keys, rateLimiters, usedTokenCluster, deviceConfiguration, clock); } @Override @@ -110,6 +92,7 @@ class DeviceControllerTest { private static KeysManager keysManager = mock(KeysManager.class); private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiter rateLimiter = mock(RateLimiter.class); + private static RedisAdvancedClusterCommands commands = mock(RedisAdvancedClusterCommands.class); private static Account account = mock(Account.class); private static Account maxedAccount = mock(Account.class); private static Device masterDevice = mock(Device.class); @@ -123,6 +106,7 @@ class DeviceControllerTest { messagesManager, keysManager, rateLimiters, + RedisClusterHelper.builder().stringCommands(commands).build(), deviceConfiguration, testClock); @@ -187,6 +171,7 @@ class DeviceControllerTest { keysManager, rateLimiters, rateLimiter, + commands, account, maxedAccount, masterDevice, @@ -226,6 +211,7 @@ class DeviceControllerTest { verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID); + verify(commands, never()).set(anyString(), anyString(), any()); } @Test @@ -248,6 +234,31 @@ class DeviceControllerTest { DeviceResponse.class); assertThat(response.getDeviceId()).isEqualTo(42L); + + verify(commands).set(anyString(), anyString(), any()); + } + + @Test + void validDeviceRegisterTestSignedTokenUsed() { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID); + + when(commands.get(anyString())).thenReturn(""); + + final Response response = resources.getJerseyTest() + .target("/v1/devices/" + verificationToken) + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(new AccountAttributes(false, 1234, null, + null, true, null), + MediaType.APPLICATION_JSON_TYPE)); + + assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); } @Test @@ -368,8 +379,10 @@ class DeviceControllerTest { verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); + verify(commands, never()).set(anyString(), anyString(), any()); } + @ParameterizedTest @MethodSource("linkDeviceAtomic") @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -416,6 +429,8 @@ class DeviceControllerTest { .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class); assertThat(response.getDeviceId()).isEqualTo(42L); + + verify(commands).set(anyString(), anyString(), any()); } private static Stream linkDeviceAtomic() { @@ -431,6 +446,50 @@ class DeviceControllerTest { ); } + @Test + void linkDeviceAtomicWithVerificationTokenUsed() { + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); + + final Device existingDevice = mock(Device.class); + when(existingDevice.getId()).thenReturn(Device.MASTER_ID); + when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + + final Optional aciSignedPreKey; + final Optional pniSignedPreKey; + final Optional aciPqLastResortPreKey; + final Optional pniPqLastResortPreKey; + + final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); + + aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair)); + pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair)); + pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + + when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); + when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); + + when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + when(commands.get(anyString())).thenReturn(""); + + final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), + new AccountAttributes(false, 1234, null, null, true, null), + new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/link") + .request() + .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + } + } + @ParameterizedTest @MethodSource @SuppressWarnings("OptionalUsedAsFieldOrParameterType")