Assume that all devices have signed pre-keys

This commit is contained in:
Jon Chambers 2023-12-08 14:27:53 -05:00 committed by Jon Chambers
parent c29113d17a
commit 28a981f29f
6 changed files with 34 additions and 82 deletions

View File

@ -37,7 +37,6 @@ import javax.ws.rs.core.Response;
import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey;
@ -96,7 +95,6 @@ public class KeysController {
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Upload new prekeys", description = "Upload new pre-keys for this device.")
@ApiResponse(responseCode = "200", description = "Indicates that new keys were successfully stored.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ -277,7 +275,6 @@ public class KeysController {
@PUT
@Path("/signed")
@Consumes(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Upload a new signed prekey",
description = """
Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys instead.

View File

@ -207,8 +207,8 @@ public class Device {
public boolean isEnabled() {
boolean hasChannel = fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId());
return (id == PRIMARY_ID && hasChannel && signedPreKey != null) ||
(id != PRIMARY_ID && hasChannel && signedPreKey != null && lastSeen > (System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30)));
return (id == PRIMARY_ID && hasChannel) ||
(id != PRIMARY_ID && hasChannel && lastSeen > (System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30)));
}
public boolean getFetchesMessages() {

View File

@ -202,11 +202,9 @@ class AuthEnablementRefreshRequirementProviderTest {
}
static Stream<Arguments> testDeviceEnabledChanged() {
final byte deviceId1 = Device.PRIMARY_ID;
final byte deviceId2 = 2;
final byte deviceId3 = 3;
return Stream.of(
Arguments.of(Map.of(deviceId1, false, deviceId2, false), Map.of(deviceId1, true, deviceId2, false)),
Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
@ -274,32 +272,6 @@ class AuthEnablementRefreshRequirementProviderTest {
verifyNoMoreInteractions(clientPresenceManager);
}
@Test
void testPrimaryDeviceDisabledAndDeviceRemoved() {
assert account.getPrimaryDevice().isEnabled();
final Set<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final byte deletedDeviceId = 2;
assertTrue(initialDeviceIds.remove(deletedDeviceId));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/disablePrimaryDeviceAndDeleteDevice/" + deletedDeviceId)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.post(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
assertTrue(account.getDevice(deletedDeviceId).isEmpty());
initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
verify(clientPresenceManager).disconnectPresence(account.getUuid(), deletedDeviceId);
verifyNoMoreInteractions(clientPresenceManager);
}
@Test
void testOnEvent() {
Response response = resources.getJerseyTest()

View File

@ -49,6 +49,7 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
@ -1351,6 +1352,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1403,6 +1408,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1463,6 +1472,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1528,12 +1541,6 @@ class AccountsManagerTest {
deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
assertThrows(MismatchedDevicesException.class,
@ -1558,12 +1565,6 @@ class AccountsManagerTest {
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
assertThrows(MismatchedDevicesException.class,

View File

@ -20,7 +20,7 @@ class DeviceTest {
@ParameterizedTest
@MethodSource
void testIsEnabled(final boolean primary, final boolean fetchesMessages, final String apnId, final String gcmId,
final ECSignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) {
final Duration timeSinceLastSeen, final boolean expectEnabled) {
final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis();
@ -29,7 +29,6 @@ class DeviceTest {
device.setFetchesMessages(fetchesMessages);
device.setApnId(apnId);
device.setGcmId(gcmId);
device.setSignedPreKey(signedPreKey);
device.setCreated(lastSeen);
device.setLastSeen(lastSeen);
@ -38,39 +37,23 @@ class DeviceTest {
private static Stream<Arguments> testIsEnabled() {
return Stream.of(
// primary fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled
Arguments.of(true, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(true, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(true, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), true),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true)
// primary fetchesMessages apnId gcmId lastSeen expectEnabled
Arguments.of(true, false, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, null, "gcm-id", Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, Duration.ofDays(60), true),
Arguments.of(true, false, "apn-id", null, Duration.ofDays(1), true),
Arguments.of(true, true, null, null, Duration.ofDays(60), true),
Arguments.of(true, true, null, null, Duration.ofDays(1), true),
Arguments.of(false, false, null, null, Duration.ofDays(60), false),
Arguments.of(false, false, null, null, Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, Duration.ofDays(1), true),
Arguments.of(false, true, null, null, Duration.ofDays(60), false),
Arguments.of(false, true, null, null, Duration.ofDays(1), true)
);
}
}

View File

@ -47,12 +47,11 @@ public class DevicesHelper {
public static void setEnabled(Device device, boolean enabled) {
if (enabled) {
device.setSignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));
device.setPhoneNumberIdentitySignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));
device.setGcmId("testGcmId" + RANDOM.nextLong());
device.setLastSeen(Util.todayInMillis());
} else {
device.setSignedPreKey(null);
device.setLastSeen(0);
}
// fail fast, to guard against a change to the isEnabled() implementation causing unexpected test behavior