Assume that all devices have signed pre-keys

This commit is contained in:
Jon Chambers 2023-12-08 14:27:53 -05:00 committed by Jon Chambers
parent c29113d17a
commit 28a981f29f
6 changed files with 34 additions and 82 deletions

View File

@ -37,7 +37,6 @@ import javax.ws.rs.core.Response;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECPreKey;
@ -96,7 +95,6 @@ public class KeysController {
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Upload new prekeys", description = "Upload new pre-keys for this device.") @Operation(summary = "Upload new prekeys", description = "Upload new pre-keys for this device.")
@ApiResponse(responseCode = "200", description = "Indicates that new keys were successfully stored.") @ApiResponse(responseCode = "200", description = "Indicates that new keys were successfully stored.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ -277,7 +275,6 @@ public class KeysController {
@PUT @PUT
@Path("/signed") @Path("/signed")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
@Operation(summary = "Upload a new signed prekey", @Operation(summary = "Upload a new signed prekey",
description = """ description = """
Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys instead. Upload a new signed elliptic-curve prekey for this device. Deprecated; use PUT /v2/keys instead.

View File

@ -207,8 +207,8 @@ public class Device {
public boolean isEnabled() { public boolean isEnabled() {
boolean hasChannel = fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId()); boolean hasChannel = fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId());
return (id == PRIMARY_ID && hasChannel && signedPreKey != null) || return (id == PRIMARY_ID && hasChannel) ||
(id != PRIMARY_ID && hasChannel && signedPreKey != null && lastSeen > (System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30))); (id != PRIMARY_ID && hasChannel && lastSeen > (System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30)));
} }
public boolean getFetchesMessages() { public boolean getFetchesMessages() {

View File

@ -202,11 +202,9 @@ class AuthEnablementRefreshRequirementProviderTest {
} }
static Stream<Arguments> testDeviceEnabledChanged() { static Stream<Arguments> testDeviceEnabledChanged() {
final byte deviceId1 = Device.PRIMARY_ID;
final byte deviceId2 = 2; final byte deviceId2 = 2;
final byte deviceId3 = 3; final byte deviceId3 = 3;
return Stream.of( return Stream.of(
Arguments.of(Map.of(deviceId1, false, deviceId2, false), Map.of(deviceId1, true, deviceId2, false)),
Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)), Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)), Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)), Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
@ -274,32 +272,6 @@ class AuthEnablementRefreshRequirementProviderTest {
verifyNoMoreInteractions(clientPresenceManager); verifyNoMoreInteractions(clientPresenceManager);
} }
@Test
void testPrimaryDeviceDisabledAndDeviceRemoved() {
assert account.getPrimaryDevice().isEnabled();
final Set<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final byte deletedDeviceId = 2;
assertTrue(initialDeviceIds.remove(deletedDeviceId));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/disablePrimaryDeviceAndDeleteDevice/" + deletedDeviceId)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.post(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
assertTrue(account.getDevice(deletedDeviceId).isEmpty());
initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
verify(clientPresenceManager).disconnectPresence(account.getUuid(), deletedDeviceId);
verifyNoMoreInteractions(clientPresenceManager);
}
@Test @Test
void testOnEvent() { void testOnEvent() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()

View File

@ -49,6 +49,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -1351,6 +1352,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1403,6 +1408,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1463,6 +1472,10 @@ class AccountsManagerTest {
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102)); DevicesHelper.createDevice(deviceId2, 0L, 102));
devices.forEach(device ->
device.setSignedPreKey(KeysHelper.signedECPreKey(ThreadLocalRandom.current().nextLong(), Curve.generateKeyPair())));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
@ -1528,12 +1541,6 @@ class AccountsManagerTest {
deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair)); deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
assertThrows(MismatchedDevicesException.class, assertThrows(MismatchedDevicesException.class,
@ -1558,12 +1565,6 @@ class AccountsManagerTest {
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)); Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
assertThrows(MismatchedDevicesException.class, assertThrows(MismatchedDevicesException.class,

View File

@ -20,7 +20,7 @@ class DeviceTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testIsEnabled(final boolean primary, final boolean fetchesMessages, final String apnId, final String gcmId, void testIsEnabled(final boolean primary, final boolean fetchesMessages, final String apnId, final String gcmId,
final ECSignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) { final Duration timeSinceLastSeen, final boolean expectEnabled) {
final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis(); final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis();
@ -29,7 +29,6 @@ class DeviceTest {
device.setFetchesMessages(fetchesMessages); device.setFetchesMessages(fetchesMessages);
device.setApnId(apnId); device.setApnId(apnId);
device.setGcmId(gcmId); device.setGcmId(gcmId);
device.setSignedPreKey(signedPreKey);
device.setCreated(lastSeen); device.setCreated(lastSeen);
device.setLastSeen(lastSeen); device.setLastSeen(lastSeen);
@ -38,39 +37,23 @@ class DeviceTest {
private static Stream<Arguments> testIsEnabled() { private static Stream<Arguments> testIsEnabled() {
return Stream.of( return Stream.of(
// primary fetchesMessages apnId gcmId signedPreKey lastSeen expectEnabled // primary fetchesMessages apnId gcmId lastSeen expectEnabled
Arguments.of(true, false, null, null, null, Duration.ofDays(60), false), Arguments.of(true, false, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, null, null, null, Duration.ofDays(1), false), Arguments.of(true, false, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false), Arguments.of(true, false, null, "gcm-id", Duration.ofDays(60), true),
Arguments.of(true, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false), Arguments.of(true, false, null, "gcm-id", Duration.ofDays(1), true),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(60), false), Arguments.of(true, false, "apn-id", null, Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", null, Duration.ofDays(1), false), Arguments.of(true, false, "apn-id", null, Duration.ofDays(1), true),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), true), Arguments.of(true, true, null, null, Duration.ofDays(60), true),
Arguments.of(true, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(true, true, null, null, Duration.ofDays(1), true),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(60), false), Arguments.of(false, false, null, null, Duration.ofDays(60), false),
Arguments.of(true, false, "apn-id", null, null, Duration.ofDays(1), false), Arguments.of(false, false, null, null, Duration.ofDays(1), false),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), true), Arguments.of(false, false, null, "gcm-id", Duration.ofDays(60), false),
Arguments.of(true, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, false, null, "gcm-id", Duration.ofDays(1), true),
Arguments.of(true, true, null, null, null, Duration.ofDays(60), false), Arguments.of(false, false, "apn-id", null, Duration.ofDays(60), false),
Arguments.of(true, true, null, null, null, Duration.ofDays(1), false), Arguments.of(false, false, "apn-id", null, Duration.ofDays(1), true),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), true), Arguments.of(false, true, null, null, Duration.ofDays(60), false),
Arguments.of(true, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true), Arguments.of(false, true, null, null, Duration.ofDays(1), true)
Arguments.of(false, false, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, false, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", null, Duration.ofDays(1), false),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, null, "gcm-id", mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, null, Duration.ofDays(1), false),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, false, "apn-id", null, mock(ECSignedPreKey.class), Duration.ofDays(1), true),
Arguments.of(false, true, null, null, null, Duration.ofDays(60), false),
Arguments.of(false, true, null, null, null, Duration.ofDays(1), false),
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(60), false),
Arguments.of(false, true, null, null, mock(ECSignedPreKey.class), Duration.ofDays(1), true)
); );
} }
} }

View File

@ -47,12 +47,11 @@ public class DevicesHelper {
public static void setEnabled(Device device, boolean enabled) { public static void setEnabled(Device device, boolean enabled) {
if (enabled) { if (enabled) {
device.setSignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));
device.setPhoneNumberIdentitySignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair())); device.setPhoneNumberIdentitySignedPreKey(KeysHelper.signedECPreKey(RANDOM.nextLong(), Curve.generateKeyPair()));
device.setGcmId("testGcmId" + RANDOM.nextLong()); device.setGcmId("testGcmId" + RANDOM.nextLong());
device.setLastSeen(Util.todayInMillis()); device.setLastSeen(Util.todayInMillis());
} else { } else {
device.setSignedPreKey(null); device.setLastSeen(0);
} }
// fail fast, to guard against a change to the isEnabled() implementation causing unexpected test behavior // fail fast, to guard against a change to the isEnabled() implementation causing unexpected test behavior