Accept signed tokens in addition to randomly-generated codes for authorizing device linking

This commit is contained in:
Jon Chambers 2023-08-03 11:35:34 -04:00 committed by Chris Eager
parent 48c7572dd5
commit 308da3343d
7 changed files with 294 additions and 19 deletions

View File

@ -84,3 +84,5 @@ currentReportingKey.secret: AAAAAAAAAAA=
currentReportingKey.salt: AAAAAAAAAAA=
turn.secret: AAAAAAAAAAA=
linkDevice.secret: AAAAAAAAAAA=

View File

@ -435,3 +435,6 @@ turn:
commandStopListener:
path: /example/path
linkDevice:
secret: secret://linkDevice.secret

View File

@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.configuration.FcmConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguration;
import org.whispersystems.textsecuregcm.configuration.GenericZkConfig;
import org.whispersystems.textsecuregcm.configuration.HCaptchaConfiguration;
import org.whispersystems.textsecuregcm.configuration.LinkDeviceSecretConfiguration;
import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
@ -300,6 +301,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private CommandStopListenerConfiguration commandStopListener;
@Valid
@NotNull
@JsonProperty
private LinkDeviceSecretConfiguration linkDevice;
public AdminEventLoggingConfiguration getAdminEventLoggingConfiguration() {
return adminEventLoggingConfiguration;
}
@ -498,4 +504,8 @@ public class WhisperServerConfiguration extends Configuration {
public CommandStopListenerConfiguration getCommandStopListener() {
return commandStopListener;
}
public LinkDeviceSecretConfiguration getLinkDeviceSecretConfiguration() {
return linkDevice;
}
}

View File

@ -756,7 +756,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new CallLinkController(rateLimiters, genericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, genericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices()),
new DeviceController(pendingDevicesManager, config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices(),
clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new),

View File

@ -0,0 +1,11 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
public record LinkDeviceSecretConfiguration(SecretBytes secret) {
}

View File

@ -12,12 +12,20 @@ import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.security.SecureRandom;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@ -66,24 +74,45 @@ public class DeviceController {
static final int MAX_DEVICES = 6;
private final StoredVerificationCodeManager pendingDevices;
private final Key verificationTokenKey;
private final AccountsManager accounts;
private final MessagesManager messages;
private final KeysManager keys;
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
private final Clock clock;
private static final String VERIFICATION_TOKEN_ALGORITHM = "HmacSHA256";
@VisibleForTesting
static final Duration TOKEN_EXPIRATION_DURATION = Duration.ofMinutes(10);
public DeviceController(StoredVerificationCodeManager pendingDevices,
byte[] linkDeviceSecret,
AccountsManager accounts,
MessagesManager messages,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> maxDeviceConfiguration) {
Map<String, Integer> maxDeviceConfiguration, final Clock clock) {
this.pendingDevices = pendingDevices;
this.verificationTokenKey = new SecretKeySpec(linkDeviceSecret, VERIFICATION_TOKEN_ALGORITHM);
this.accounts = accounts;
this.messages = messages;
this.keys = keys;
this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration;
this.clock = clock;
// Fail fast: reject bad keys
try {
final Mac mac = Mac.getInstance(VERIFICATION_TOKEN_ALGORITHM);
mac.init(verificationTokenKey);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("All Java implementations must support HmacSHA256", e);
} catch (final InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
}
@Timed
@ -239,12 +268,86 @@ public class DeviceController {
accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities));
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {
@VisibleForTesting
VerificationCode generateVerificationCode() {
SecureRandom random = new SecureRandom();
int randomInt = 100000 + random.nextInt(900000);
return new VerificationCode(randomInt);
}
private Mac getInitializedMac() {
try {
final Mac mac = Mac.getInstance(VERIFICATION_TOKEN_ALGORITHM);
mac.init(verificationTokenKey);
return mac;
} catch (final NoSuchAlgorithmException | InvalidKeyException e) {
// All Java implementations must support HmacSHA256 and we checked the key at construction time, so this can never
// happen
throw new AssertionError(e);
}
}
@VisibleForTesting
String generateVerificationToken(final UUID aci) {
final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8));
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
}
@VisibleForTesting
Optional<UUID> checkVerificationToken(final String verificationToken) {
final String[] claimsAndSignature = verificationToken.split(":", 2);
if (claimsAndSignature.length != 2) {
return Optional.empty();
}
final byte[] expectedSignature = getInitializedMac().doFinal(claimsAndSignature[0].getBytes(StandardCharsets.UTF_8));
final byte[] providedSignature;
try {
providedSignature = Base64.getUrlDecoder().decode(claimsAndSignature[1]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
if (!MessageDigest.isEqual(expectedSignature, providedSignature)) {
return Optional.empty();
}
final String[] aciAndTimestamp = claimsAndSignature[0].split("\\.", 2);
if (aciAndTimestamp.length != 2) {
return Optional.empty();
}
final UUID aci;
try {
aci = UUID.fromString(aciAndTimestamp[0]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
final Instant timestamp;
try {
timestamp = Instant.ofEpochMilli(Long.parseLong(aciAndTimestamp[1]));
} catch (final NumberFormatException e) {
return Optional.empty();
}
final Instant tokenExpiration = timestamp.plus(TOKEN_EXPIRATION_DURATION);
if (tokenExpiration.isBefore(clock.instant())) {
return Optional.empty();
}
return Optional.of(aci);
}
static boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities) {
boolean isDowngrade = false;
@ -268,13 +371,15 @@ public class DeviceController {
rateLimiters.getVerifyDeviceLimiter().validate(phoneNumber);
Optional<StoredVerificationCode> storedVerificationCode = pendingDevices.getCodeForNumber(phoneNumber);
final Account account = checkVerificationToken(verificationCode)
.flatMap(accounts::getByAccountIdentifier)
.or(() -> {
final boolean verificationCodeValid = pendingDevices.getCodeForNumber(phoneNumber)
.map(storedVerificationCode -> storedVerificationCode.isValid(verificationCode))
.orElse(false);
if (storedVerificationCode.isEmpty() || !storedVerificationCode.get().isValid(verificationCode)) {
throw new WebApplicationException(Response.status(403).build());
}
final Account account = accounts.getByE164(phoneNumber)
return verificationCodeValid ? accounts.getByE164(phoneNumber) : Optional.empty();
})
.orElseThrow(ForbiddenException::new);
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {

View File

@ -24,10 +24,14 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.ws.rs.Path;
@ -73,6 +77,7 @@ import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -82,12 +87,15 @@ class DeviceControllerTest {
static class DumbVerificationDeviceController extends DeviceController {
public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices,
byte[] linkDeviceSecret,
AccountsManager accounts,
MessagesManager messages,
KeysManager keys,
RateLimiters rateLimiters,
Map<String, Integer> deviceConfiguration) {
super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration);
Map<String, Integer> deviceConfiguration,
Clock clock) {
super(pendingDevices, linkDeviceSecret, accounts, messages, keys, rateLimiters, deviceConfiguration, clock);
}
@Override
@ -106,8 +114,17 @@ class DeviceControllerTest {
private static Account maxedAccount = mock(Account.class);
private static Device masterDevice = mock(Device.class);
private static ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static TestClock testClock = TestClock.now();
private static DeviceController deviceController = new DumbVerificationDeviceController(pendingDevicesManager,
generateLinkDeviceSecret(),
accountsManager,
messagesManager,
keysManager,
rateLimiters,
deviceConfiguration,
testClock);
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
@ -116,14 +133,15 @@ class DeviceControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keysManager,
rateLimiters,
deviceConfiguration))
.addResource(deviceController)
.build();
private static byte[] generateLinkDeviceSecret() {
final byte[] linkDeviceSecret = new byte[32];
new SecureRandom().nextBytes(linkDeviceSecret);
return linkDeviceSecret;
}
@BeforeEach
void setup() {
@ -174,6 +192,8 @@ class DeviceControllerTest {
masterDevice,
clientPresenceManager
);
testClock.unpin();
}
@Test
@ -208,6 +228,28 @@ class DeviceControllerTest {
verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID);
}
@Test
void validDeviceRegisterTestSignedToken() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID);
final DeviceResponse response = resources.getJerseyTest()
.target("/v1/devices/" + verificationToken)
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(new AccountAttributes(false, 1234, null,
null, true, null),
MediaType.APPLICATION_JSON_TYPE),
DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
}
@Test
void verifyDeviceWithNullAccountAttributes() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
@ -328,6 +370,54 @@ class DeviceControllerTest {
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get()));
}
@ParameterizedTest
@MethodSource("linkDeviceAtomic")
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void linkDeviceAtomicWithVerificationToken(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final Optional<ECSignedPreKey> aciSignedPreKey;
final Optional<ECSignedPreKey> pniSignedPreKey;
final Optional<KEMSignedPreKey> aciPqLastResortPreKey;
final Optional<KEMSignedPreKey> pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = Optional.of(KeysHelper.signedECPreKey(1, aciIdentityKeyPair));
pniSignedPreKey = Optional.of(KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
aciPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair));
pniPqLastResortPreKey = Optional.of(KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
when(account.getIdentityKey()).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getPhoneNumberIdentityKey()).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
new AccountAttributes(fetchesMessages, 1234, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
final DeviceResponse response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@ -841,4 +931,57 @@ class DeviceControllerTest {
verify(keysManager).delete(AuthHelper.VALID_UUID, deviceId);
}
@Test
void checkVerificationToken() {
final UUID uuid = UUID.randomUUID();
assertEquals(Optional.of(uuid),
deviceController.checkVerificationToken(deviceController.generateVerificationToken(uuid)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
testClock.pin(currentTime);
assertEquals(Optional.empty(),
deviceController.checkVerificationToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() {
final Instant tokenTimestamp = testClock.instant();
return Stream.of(
// Expired token
Arguments.of(deviceController.generateVerificationToken(UUID.randomUUID()),
tokenTimestamp.plus(DeviceController.TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
}