Allow, but do not require, message delivery to devices without active delivery channels

This commit is contained in:
Jon Chambers 2024-06-25 09:53:31 -04:00 committed by GitHub
parent f5ce34fb69
commit d306cafbcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 189 additions and 281 deletions

View File

@ -22,12 +22,12 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
/** /**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and * This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in
* {@link Device#hasMessageDeliveryChannel()}. * {@link Device#hasMessageDeliveryChannel()}.
* <p> * <p>
* If a change in {@link Account#isEnabled()} or any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active * If a change in any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active WebSocket
* WebSocket connections for the account must be closed in order for clients to get a refreshed * connections for the account must be closed in order for clients to get a refreshed {@link io.dropwizard.auth.Auth}
* {@link io.dropwizard.auth.Auth} object with a current device list. * object with a current device list.
* *
* @see AuthenticatedAccount * @see AuthenticatedAccount
*/ */
@ -48,9 +48,8 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@Override @Override
public void handleRequestFiltered(final RequestEvent requestEvent) { public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(ChangesDeviceEnabledState.class) != null) { if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(ChangesDeviceEnabledState.class) != null) {
// The authenticated principal, if any, will be available after filters have run. // The authenticated principal, if any, will be available after filters have run. Now that the account is known,
// Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out // capture a snapshot of the account's devices before carrying out the requests business logic.
// the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()).ifPresent(account -> ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()).ifPresent(account ->
setAccount(requestEvent.getContainerRequest(), account)); setAccount(requestEvent.getContainerRequest(), account));
} }
@ -66,8 +65,8 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@Override @Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) { public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did // Now that the request is finished, check whether `hasMessageDeliveryChannel` changed for any of the devices. If
// change or if a devices was added or removed, all devices must disconnect and reauthenticate. // the value did change or if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) { if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled = @SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =

View File

@ -18,6 +18,8 @@ import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class OptionalAccess { public class OptionalAccess {
public static String ALL_DEVICES_SELECTOR = "*";
public static void verify(Optional<Account> requestAccount, public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey, Optional<Anonymous> accessKey,
Optional<Account> targetAccount, Optional<Account> targetAccount,
@ -26,12 +28,12 @@ public class OptionalAccess {
try { try {
verify(requestAccount, accessKey, targetAccount); verify(requestAccount, accessKey, targetAccount);
if (!deviceSelector.equals("*")) { if (!ALL_DEVICES_SELECTOR.equals(deviceSelector)) {
byte deviceId = Byte.parseByte(deviceSelector); byte deviceId = Byte.parseByte(deviceSelector);
Optional<Device> targetDevice = targetAccount.get().getDevice(deviceId); Optional<Device> targetDevice = targetAccount.get().getDevice(deviceId);
if (targetDevice.isPresent() && targetDevice.get().hasMessageDeliveryChannel()) { if (targetDevice.isPresent()) {
return; return;
} }
@ -48,11 +50,10 @@ public class OptionalAccess {
public static void verify(Optional<Account> requestAccount, public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey, Optional<Anonymous> accessKey,
Optional<Account> targetAccount) Optional<Account> targetAccount) {
{
if (requestAccount.isPresent()) { if (requestAccount.isPresent()) {
// Authenticated requests are never unauthorized; if the target exists and is enabled, return OK, otherwise throw not-found. // Authenticated requests are never unauthorized; if the target exists, return OK, otherwise throw not-found.
if (targetAccount.isPresent() && targetAccount.get().isEnabled()) { if (targetAccount.isPresent()) {
return; return;
} else { } else {
throw new NotFoundException(); throw new NotFoundException();
@ -63,7 +64,7 @@ public class OptionalAccess {
// has unrestricted unidentified access, callers need to supply a fake access key. Likewise, if // has unrestricted unidentified access, callers need to supply a fake access key. Likewise, if
// the target account does not exist, we *also* report unauthorized here (*not* not-found, // the target account does not exist, we *also* report unauthorized here (*not* not-found,
// since that would provide a free exists check). // since that would provide a free exists check).
if (accessKey.isEmpty() || !targetAccount.map(Account::isEnabled).orElse(false)) { if (accessKey.isEmpty() || targetAccount.isEmpty()) {
throw new NotAuthorizedException(Response.Status.UNAUTHORIZED); throw new NotAuthorizedException(Response.Status.UNAUTHORIZED);
} }

View File

@ -313,7 +313,7 @@ public class KeysController {
@ApiResponse(responseCode = "200", description = "Indicates at least one prekey was available for at least one requested device.", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "Indicates at least one prekey was available for at least one requested device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "400", description = "A group send endorsement and other authorization (account authentication or unidentified-access key) were both provided.") @ApiResponse(responseCode = "400", description = "A group send endorsement and other authorization (account authentication or unidentified-access key) were both provided.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed and unidentified-access key or group send endorsement token was not supplied or invalid.") @ApiResponse(responseCode = "401", description = "Account authentication check failed and unidentified-access key or group send endorsement token was not supplied or invalid.")
@ApiResponse(responseCode = "404", description = "Requested identity or device does not exist, is not active, or has no available prekeys.") @ApiResponse(responseCode = "404", description = "Requested identity or device does not exist or device has no available prekeys.")
@ApiResponse(responseCode = "429", description = "Rate limit exceeded.", headers = @Header( @ApiResponse(responseCode = "429", description = "Rate limit exceeded.", headers = @Header(
name = "Retry-After", name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed")) description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
@ -440,11 +440,11 @@ public class KeysController {
private List<Device> parseDeviceId(String deviceId, Account account) { private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) { if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::hasMessageDeliveryChannel).toList(); return account.getDevices();
} }
try { try {
byte id = Byte.parseByte(deviceId); byte id = Byte.parseByte(deviceId);
return account.getDevice(id).filter(Device::hasMessageDeliveryChannel).map(List::of).orElse(List.of()); return account.getDevice(id).map(List::of).orElse(List.of());
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
} }

View File

@ -369,7 +369,7 @@ public class MessageController {
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
} }
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice(); boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getDevices().size() > 1;
// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify // We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from // we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from

View File

@ -39,7 +39,6 @@ class KeysGrpcHelper {
: Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId))); : Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));
return devices return devices
.filter(Device::hasMessageDeliveryChannel)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(device -> Flux.merge( .flatMap(device -> Flux.merge(
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())), Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),

View File

@ -303,12 +303,6 @@ public class Account {
.allMatch(device -> device.getCapabilities() != null && predicate.test(device.getCapabilities())); .allMatch(device -> device.getCapabilities() != null && predicate.test(device.getCapabilities()));
} }
public boolean isEnabled() {
requireNotStale();
return getPrimaryDevice().hasMessageDeliveryChannel();
}
public byte getNextDeviceId() { public byte getNextDeviceId() {
requireNotStale(); requireNotStale();
@ -325,14 +319,6 @@ public class Account {
return candidateId; return candidateId;
} }
public boolean hasEnabledLinkedDevice() {
requireNotStale();
return devices.stream()
.filter(d -> Device.PRIMARY_ID != d.getId())
.anyMatch(Device::hasMessageDeliveryChannel);
}
public void setIdentityKey(final IdentityKey identityKey) { public void setIdentityKey(final IdentityKey identityKey) {
requireNotStale(); requireNotStale();
@ -503,12 +489,6 @@ public class Account {
this.discoverableByPhoneNumber = discoverableByPhoneNumber; this.discoverableByPhoneNumber = discoverableByPhoneNumber;
} }
public boolean shouldBeVisibleInDirectory() {
requireNotStale();
return isEnabled() && isDiscoverableByPhoneNumber();
}
public int getVersion() { public int getVersion() {
requireNotStale(); requireNotStale();

View File

@ -443,7 +443,7 @@ public class Accounts extends AbstractDynamoDbStore {
.expressionAttributeValues(Map.of( .expressionAttributeValues(Map.of(
":number", numberAttr, ":number", numberAttr,
":data", accountDataAttributeValue(account), ":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.isDiscoverableByPhoneNumber()),
":pni", pniAttr, ":pni", pniAttr,
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))) ":version_increment", AttributeValues.fromInt(1)))
@ -924,7 +924,7 @@ public class Accounts extends AbstractDynamoDbStore {
final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of( final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of(
":data", accountDataAttributeValue(account), ":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()), ":cds", AttributeValues.fromBool(account.isDiscoverableByPhoneNumber()),
":version", AttributeValues.fromInt(account.getVersion()), ":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1))); ":version_increment", AttributeValues.fromInt(1)));
@ -1359,7 +1359,7 @@ public class Accounts extends AbstractDynamoDbStore {
ATTR_PNI_UUID, pniUuidAttr, ATTR_PNI_UUID, pniUuidAttr,
ATTR_ACCOUNT_DATA, accountDataAttributeValue(account), ATTR_ACCOUNT_DATA, accountDataAttributeValue(account),
ATTR_VERSION, AttributeValues.fromInt(account.getVersion()), ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory()))); ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.isDiscoverableByPhoneNumber())));
// Add the UAK if it's in the account // Add the UAK if it's in the account
account.getUnidentifiedAccessKey() account.getUnidentifiedAccessKey()

View File

@ -490,7 +490,6 @@ public class AccountsManager {
account.getDevices() account.getDevices()
.stream() .stream()
.filter(Device::hasMessageDeliveryChannel)
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId()))); .forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));
account.setPhoneNumberIdentityKey(pniIdentityKey); account.setPhoneNumberIdentityKey(pniIdentityKey);

View File

@ -89,7 +89,6 @@ public class DestinationDeviceValidator {
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException { final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Byte> accountDeviceIds = account.getDevices().stream() final Set<Byte> accountDeviceIds = account.getDevices().stream()
.filter(Device::hasMessageDeliveryChannel)
.map(Device::getId) .map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId)) .filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
@ -97,6 +96,12 @@ public class DestinationDeviceValidator {
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds); final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds); missingDeviceIds.removeAll(messageDeviceIds);
// Temporarily "excuse" missing devices if they're missing a message delivery channel as a transitional measure
missingDeviceIds.removeAll(account.getDevices().stream()
.filter(device -> !device.hasMessageDeliveryChannel())
.map(Device::getId)
.collect(Collectors.toSet()));
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds); final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds); extraDeviceIds.removeAll(accountDeviceIds);

View File

@ -161,7 +161,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true); when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials); when(device.getAuthTokenHash()).thenReturn(credentials);
@ -191,7 +190,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true); when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials); when(device.getAuthTokenHash()).thenReturn(credentials);
@ -224,7 +222,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(authenticatedDevice)); when(account.getDevice(deviceId)).thenReturn(Optional.of(authenticatedDevice));
when(account.isEnabled()).thenReturn(accountEnabled);
when(authenticatedDevice.getId()).thenReturn(deviceId); when(authenticatedDevice.getId()).thenReturn(deviceId);
when(authenticatedDevice.hasMessageDeliveryChannel()).thenReturn(deviceEnabled); when(authenticatedDevice.hasMessageDeliveryChannel()).thenReturn(deviceEnabled);
when(authenticatedDevice.getAuthTokenHash()).thenReturn(credentials); when(authenticatedDevice.getAuthTokenHash()).thenReturn(credentials);
@ -260,7 +257,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true); when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials); when(device.getAuthTokenHash()).thenReturn(credentials);
@ -297,7 +293,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true); when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials); when(device.getAuthTokenHash()).thenReturn(credentials);
@ -325,7 +320,6 @@ class AccountAuthenticatorTest {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId); when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true); when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials); when(device.getAuthTokenHash()).thenReturn(credentials);

View File

@ -5,152 +5,141 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.nio.charset.StandardCharsets;
import java.util.Base64; import java.util.Base64;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import org.junit.jupiter.api.Test; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class OptionalAccessTest { class OptionalAccessTest {
@Test @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void testUnidentifiedMissingTarget() { @ParameterizedTest
try { @MethodSource
OptionalAccess.verify(Optional.empty(), Optional.empty(), Optional.empty()); void verify(final Optional<Account> requestAccount,
throw new AssertionError("should fail"); final Optional<Anonymous> accessKey,
} catch (WebApplicationException e) { final Optional<Account> targetAccount,
assertEquals(e.getResponse().getStatus(), 401); final String deviceSelector,
} final OptionalInt expectedStatusCode) {
expectedStatusCode.ifPresentOrElse(statusCode -> {
final WebApplicationException webApplicationException = assertThrows(WebApplicationException.class,
() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector));
assertEquals(statusCode, webApplicationException.getResponse().getStatus());
}, () -> assertDoesNotThrow(() -> OptionalAccess.verify(requestAccount, accessKey, targetAccount, deviceSelector)));
} }
@Test private static List<Arguments> verify() {
void testUnidentifiedMissingTargetDevice() { final String unidentifiedAccessKey = RandomStringUtils.randomAlphanumeric(16);
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try { final Anonymous correctUakHeader =
OptionalAccess.verify(Optional.empty(), Optional.of(new Anonymous(Base64.getEncoder().encodeToString("1234".getBytes()))), Optional.of(account), "10"); new Anonymous(Base64.getEncoder().encodeToString(unidentifiedAccessKey.getBytes()));
} catch (WebApplicationException e) {
assertEquals(e.getResponse().getStatus(), 401);
}
}
@Test final Anonymous incorrectUakHeader =
void testUnidentifiedBadTargetDevice() { new Anonymous(Base64.getEncoder().encodeToString((unidentifiedAccessKey + "-incorrect").getBytes()));
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try { final Account targetAccount = mock(Account.class);
OptionalAccess.verify(Optional.empty(), Optional.of(new Anonymous(Base64.getEncoder().encodeToString("1234".getBytes()))), Optional.of(account), "$$"); when(targetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
} catch (WebApplicationException e) { when(targetAccount.getUnidentifiedAccessKey())
assertEquals(e.getResponse().getStatus(), 422); .thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8)));
}
}
final Account allowAllTargetAccount = mock(Account.class);
when(allowAllTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
when(allowAllTargetAccount.isUnrestrictedUnidentifiedAccess()).thenReturn(true);
@Test final Account noUakTargetAccount = mock(Account.class);
void testUnidentifiedBadCode() { when(noUakTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
Account account = mock(Account.class); when(noUakTargetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.empty());
when(account.isEnabled()).thenReturn(true);
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try { final Account inactiveTargetAccount = mock(Account.class);
OptionalAccess.verify(Optional.empty(), Optional.of(new Anonymous(Base64.getEncoder().encodeToString("5678".getBytes()))), Optional.of(account)); when(inactiveTargetAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(mock(Device.class)));
throw new AssertionError("should fail"); when(inactiveTargetAccount.getUnidentifiedAccessKey())
} catch (WebApplicationException e) { .thenReturn(Optional.of(unidentifiedAccessKey.getBytes(StandardCharsets.UTF_8)));
assertEquals(e.getResponse().getStatus(), 401);
}
}
@Test return List.of(
void testIdentifiedMissingTarget() { // Unidentified caller; correct UAK
Account account = mock(Account.class); Arguments.of(Optional.empty(),
when(account.isEnabled()).thenReturn(true); Optional.of(correctUakHeader),
Optional.of(targetAccount),
OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
try { // Identified caller; no UAK needed
OptionalAccess.verify(Optional.of(account), Optional.empty(), Optional.empty()); Arguments.of(Optional.of(mock(Account.class)),
throw new AssertionError("should fail"); Optional.empty(),
} catch (WebApplicationException e) { Optional.of(targetAccount),
assertEquals(e.getResponse().getStatus(), 404); OptionalAccess.ALL_DEVICES_SELECTOR,
} OptionalInt.empty()),
}
@Test // Unidentified caller; target account not found
void testUnsolicitedBadTarget() { Arguments.of(Optional.empty(),
Account account = mock(Account.class); Optional.empty(),
when(account.isUnrestrictedUnidentifiedAccess()).thenReturn(false); Optional.empty(),
when(account.isEnabled()).thenReturn(true); OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.of(401)),
try { // Identified caller; target account not found
OptionalAccess.verify(Optional.empty(), Optional.empty(), Optional.of(account)); Arguments.of(Optional.of(mock(Account.class)),
throw new AssertionError("should fail"); Optional.empty(),
} catch (WebApplicationException e) { Optional.empty(),
assertEquals(e.getResponse().getStatus(), 401); OptionalAccess.ALL_DEVICES_SELECTOR,
} OptionalInt.of(404)),
}
@Test // Unidentified caller; target account found, but target device not found
void testUnsolicitedGoodTarget() { Arguments.of(Optional.empty(),
Account account = mock(Account.class); Optional.of(correctUakHeader),
Anonymous random = mock(Anonymous.class); Optional.of(targetAccount),
when(account.isUnrestrictedUnidentifiedAccess()).thenReturn(true); String.valueOf(Device.PRIMARY_ID + 1),
when(account.isEnabled()).thenReturn(true); OptionalInt.of(401)),
OptionalAccess.verify(Optional.empty(), Optional.of(random), Optional.of(account));
}
@Test // Unidentified caller; target account found, but incorrect UAK provided
void testUnidentifiedGoodTarget() { Arguments.of(Optional.empty(),
Account account = mock(Account.class); Optional.of(incorrectUakHeader),
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes())); Optional.of(targetAccount),
when(account.isEnabled()).thenReturn(true); OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalAccess.verify(Optional.empty(), Optional.of(new Anonymous(Base64.getEncoder().encodeToString("1234".getBytes()))), Optional.of(account)); OptionalInt.of(401)),
}
@Test // Unidentified caller; target account found, but has no UAK
void testUnidentifiedTargetMissingAccessKey() { Arguments.of(Optional.empty(),
Account account = mock(Account.class); Optional.of(correctUakHeader),
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.empty()); Optional.of(noUakTargetAccount),
when(account.isEnabled()).thenReturn(true); OptionalAccess.ALL_DEVICES_SELECTOR,
try { OptionalInt.of(401)),
OptionalAccess.verify(
Optional.empty(),
Optional.of(new Anonymous(Base64.getEncoder().encodeToString("1234".getBytes()))),
Optional.of(account));
throw new AssertionError("should fail");
} catch (WebApplicationException e) {
assertEquals(e.getResponse().getStatus(), 401);
}
}
@Test // Unidentified caller; target account found, allows unrestricted unidentified access
void testUnidentifiedInactive() { Arguments.of(Optional.empty(),
Account account = mock(Account.class); Optional.of(incorrectUakHeader),
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes())); Optional.of(allowAllTargetAccount),
when(account.isEnabled()).thenReturn(false); OptionalAccess.ALL_DEVICES_SELECTOR,
OptionalInt.empty()),
try { // Unidentified caller; target account found, but inactive
OptionalAccess.verify(Optional.empty(), Optional.of(new Anonymous(Base64.getEncoder().encodeToString("1234".getBytes()))), Optional.of(account)); Arguments.of(Optional.empty(),
throw new AssertionError(); Optional.of(correctUakHeader),
} catch (WebApplicationException e) { Optional.of(inactiveTargetAccount),
assertEquals(e.getResponse().getStatus(), 401); OptionalAccess.ALL_DEVICES_SELECTOR,
} OptionalInt.empty()),
}
@Test // Malformed device ID
void testIdentifiedGoodTarget() { Arguments.of(Optional.empty(),
Account source = mock(Account.class); Optional.of(correctUakHeader),
Account target = mock(Account.class); Optional.of(targetAccount),
when(target.isEnabled()).thenReturn(true); "not a valid identifier",
OptionalAccess.verify(Optional.of(source), Optional.empty(), Optional.of(target)); OptionalInt.of(422))
);
} }
} }

View File

@ -143,7 +143,6 @@ class DeviceControllerTest {
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.isEnabled()).thenReturn(false);
when(account.isPaymentActivationSupported()).thenReturn(false); when(account.isPaymentActivationSupported()).thenReturn(false);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));

View File

@ -240,7 +240,6 @@ class KeysControllerTest {
when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4)); when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty()); when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices); when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY); when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(existsAccount.getIdentityKey(IdentityType.PNI)).thenReturn(PNI_IDENTITY_KEY); when(existsAccount.getIdentityKey(IdentityType.PNI)).thenReturn(PNI_IDENTITY_KEY);
when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER); when(existsAccount.getNumber()).thenReturn(EXISTS_NUMBER);
@ -676,7 +675,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getDevicesCount()).isEqualTo(4);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
@ -711,12 +710,15 @@ class KeysControllerTest {
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@ -746,7 +748,7 @@ class KeysControllerTest {
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponse.class); .get(PreKeyResponse.class);
assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getDevicesCount()).isEqualTo(4);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
@ -789,10 +791,13 @@ class KeysControllerTest {
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID3);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }

View File

@ -1584,18 +1584,20 @@ class MessageControllerTest {
void sendMultiRecipientMessageMismatchedDevices(final ServiceIdentifier serviceIdentifier) void sendMultiRecipientMessageMismatchedDevices(final ServiceIdentifier serviceIdentifier)
throws JsonProcessingException { throws JsonProcessingException {
final byte extraDeviceId = MULTI_DEVICE_ID3 + 1;
final List<Recipient> recipients = List.of( final List<Recipient> recipients = List.of(
new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]), new Recipient(serviceIdentifier, MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]), new Recipient(serviceIdentifier, MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, new byte[48]),
new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48])); new Recipient(serviceIdentifier, MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, new byte[48]),
new Recipient(serviceIdentifier, extraDeviceId, 1234, new byte[48]));
// initialize our binary payload and create an input stream // initialize our binary payload and create an input stream
byte[] buffer = new byte[2048]; final byte[] buffer = new byte[2048];
// InputStream stream = initializeMultiPayload(recipientUUID, buffer); final InputStream stream = initializeMultiPayload(recipients, buffer, true);
InputStream stream = initializeMultiPayload(recipients, buffer, true);
// set up the entity to use in our PUT request // set up the entity to use in our PUT request
Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE); final Entity<InputStream> entity = Entity.entity(stream, MultiRecipientMessageProvider.MEDIA_TYPE);
// start building the request // start building the request
final Invocation.Builder invocationBuilder = resources final Invocation.Builder invocationBuilder = resources
@ -1619,7 +1621,7 @@ class MessageControllerTest {
.constructCollectionType(List.class, AccountMismatchedDevices.class)); .constructCollectionType(List.class, AccountMismatchedDevices.class));
assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier, assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(Collections.emptyList(), List.of(MULTI_DEVICE_ID3)))), new MismatchedDevices(Collections.emptyList(), List.of(extraDeviceId)))),
mismatchedDevices); mismatchedDevices);
} }
} }

View File

@ -199,7 +199,6 @@ class ProfileControllerTest {
when(profileAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY); when(profileAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_TWO_PHONE_NUMBER_IDENTITY_KEY);
when(profileAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID_TWO); when(profileAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID_TWO);
when(profileAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI_TWO); when(profileAccount.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI_TWO);
when(profileAccount.isEnabled()).thenReturn(true);
when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty()); when(profileAccount.getCurrentProfileVersion()).thenReturn(Optional.empty());
when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH)); when(profileAccount.getUsernameHash()).thenReturn(Optional.of(USERNAME_HASH));
when(profileAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); when(profileAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
@ -209,7 +208,6 @@ class ProfileControllerTest {
when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(capabilitiesAccount.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(capabilitiesAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTITY_KEY); when(capabilitiesAccount.getIdentityKey(IdentityType.ACI)).thenReturn(ACCOUNT_IDENTITY_KEY);
when(capabilitiesAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY); when(capabilitiesAccount.getIdentityKey(IdentityType.PNI)).thenReturn(ACCOUNT_PHONE_NUMBER_IDENTITY_KEY);
when(capabilitiesAccount.isEnabled()).thenReturn(true);
when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty()); when(accountsManager.getByServiceIdentifier(any())).thenReturn(Optional.empty());
@ -1012,7 +1010,6 @@ class ProfileControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getCurrentProfileVersion()).thenReturn(Optional.of(versionHex("version"))); when(account.getCurrentProfileVersion()).thenReturn(Optional.of(versionHex("version")));
when(account.isEnabled()).thenReturn(true);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(profilesManager.get(any(), any())).thenReturn(Optional.empty()); when(profilesManager.get(any(), any())).thenReturn(Optional.empty());
@ -1168,7 +1165,6 @@ class ProfileControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version)); when(account.getCurrentProfileVersion()).thenReturn(Optional.of(version));
when(account.isEnabled()).thenReturn(true);
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
final Instant expiration = Instant.now().plus(org.whispersystems.textsecuregcm.util.ProfileHelper.EXPIRING_PROFILE_KEY_CREDENTIAL_EXPIRATION) final Instant expiration = Instant.now().plus(org.whispersystems.textsecuregcm.util.ProfileHelper.EXPIRING_PROFILE_KEY_CREDENTIAL_EXPIRATION)
@ -1234,7 +1230,6 @@ class ProfileControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.isEnabled()).thenReturn(true);
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY)); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of(UNIDENTIFIED_ACCESS_KEY));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(account)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(account));

View File

@ -9,8 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException; import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
@ -26,7 +26,6 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;

View File

@ -96,32 +96,6 @@ class AccountTest {
} }
@Test
void testIsEnabled() {
final Device enabledPrimaryDevice = mock(Device.class);
final Device enabledLinkedDevice = mock(Device.class);
final Device disabledPrimaryDevice = mock(Device.class);
final Device disabledLinkedDevice = mock(Device.class);
when(enabledPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinkedDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(disabledPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(false);
when(disabledLinkedDevice.hasMessageDeliveryChannel()).thenReturn(false);
when(enabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
final byte deviceId2 = 2;
when(enabledLinkedDevice.getId()).thenReturn(deviceId2);
when(disabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(disabledLinkedDevice.getId()).thenReturn(deviceId2);
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice)).isEnabled());
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice, enabledLinkedDevice)).isEnabled());
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice, disabledLinkedDevice)).isEnabled());
assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledPrimaryDevice)).isEnabled());
assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledPrimaryDevice, enabledLinkedDevice)).isEnabled());
assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledPrimaryDevice, disabledLinkedDevice)).isEnabled());
}
@Test @Test
void testIsTransferSupported() { void testIsTransferSupported() {
final Device transferCapablePrimaryDevice = mock(Device.class); final Device transferCapablePrimaryDevice = mock(Device.class);
@ -308,49 +282,4 @@ class AccountTest {
final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get(); final JsonFilter jsonFilterAnnotation = (JsonFilter) maybeJsonFilterAnnotation.get();
assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value()); assertEquals(Account.class.getSimpleName(), jsonFilterAnnotation.value());
} }
@ParameterizedTest
@MethodSource
public void testHasEnabledLinkedDevice(final Account account, final boolean expect) {
assertEquals(expect, account.hasEnabledLinkedDevice());
}
static Stream<Arguments> testHasEnabledLinkedDevice() {
final Device enabledPrimary = mock(Device.class);
when(enabledPrimary.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
final Device disabledPrimary = mock(Device.class);
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
final byte linked1DeviceId = Device.PRIMARY_ID + 1;
final Device enabledLinked1 = mock(Device.class);
when(enabledLinked1.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);
final Device disabledLinked1 = mock(Device.class);
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);
final byte linked2DeviceId = Device.PRIMARY_ID + 2;
final Device enabledLinked2 = mock(Device.class);
when(enabledLinked2.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);
final Device disabledLinked2 = mock(Device.class);
when(disabledLinked2.getId()).thenReturn(linked2DeviceId);
return Stream.of(
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary)), false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123", List.of(enabledPrimary, disabledLinked1)),
false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, disabledLinked1, disabledLinked2)), false),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, enabledLinked1, disabledLinked2)), true),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(enabledPrimary, disabledLinked1, enabledLinked2)), true),
Arguments.of(AccountsHelper.generateTestAccount("+14155550123",
List.of(disabledLinked2, enabledLinked1, enabledLinked2)), true)
);
}
} }

View File

@ -1066,11 +1066,13 @@ class AccountsManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair(); final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of( final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair),
deviceId3, KeysHelper.signedECPreKey(3, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of( final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(4, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); deviceId2, KeysHelper.signedKEMPreKey(5, identityKeyPair),
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); deviceId3, KeysHelper.signedKEMPreKey(6, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202, deviceId3, 203);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
@ -1097,7 +1099,9 @@ class AccountsManagerTest {
verify(keysManager).getPqEnabledDevices(uuid); verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any());
verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId3), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any());
verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(deviceId3), any());
verifyNoMoreInteractions(keysManager); verifyNoMoreInteractions(keysManager);
} }

View File

@ -315,7 +315,7 @@ class AccountsTest {
Accounts.ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()), Accounts.ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
Accounts.ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)), Accounts.ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(account)),
Accounts.ATTR_VERSION, AttributeValues.fromInt(account.getVersion()), Accounts.ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
Accounts.ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory()))) Accounts.ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.isDiscoverableByPhoneNumber())))
.build()) .build())
.build(); .build();

View File

@ -145,12 +145,10 @@ public class AccountsHelper {
case "getDevices" -> when(updatedAccount.getDevices()).thenAnswer(stubbing); case "getDevices" -> when(updatedAccount.getDevices()).thenAnswer(stubbing);
case "getDevice" -> when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); case "getDevice" -> when(updatedAccount.getDevice(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);
case "getPrimaryDevice" -> when(updatedAccount.getPrimaryDevice()).thenAnswer(stubbing); case "getPrimaryDevice" -> when(updatedAccount.getPrimaryDevice()).thenAnswer(stubbing);
case "isEnabled" -> when(updatedAccount.isEnabled()).thenAnswer(stubbing);
case "isDiscoverableByPhoneNumber" -> when(updatedAccount.isDiscoverableByPhoneNumber()).thenAnswer(stubbing); case "isDiscoverableByPhoneNumber" -> when(updatedAccount.isDiscoverableByPhoneNumber()).thenAnswer(stubbing);
case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing); case "getNextDeviceId" -> when(updatedAccount.getNextDeviceId()).thenAnswer(stubbing);
case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing); case "isPaymentActivationSupported" -> when(updatedAccount.isPaymentActivationSupported()).thenAnswer(stubbing);
case "isDeleteSyncSupported" -> when(updatedAccount.isDeleteSyncSupported()).thenAnswer(stubbing); case "isDeleteSyncSupported" -> when(updatedAccount.isDeleteSyncSupported()).thenAnswer(stubbing);
case "hasEnabledLinkedDevice" -> when(updatedAccount.hasEnabledLinkedDevice()).thenAnswer(stubbing);
case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing); case "getRegistrationLock" -> when(updatedAccount.getRegistrationLock()).thenAnswer(stubbing);
case "getIdentityKey" -> case "getIdentityKey" ->
when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing); when(updatedAccount.getIdentityKey(stubbing.getInvocation().getArgument(0))).thenAnswer(stubbing);

View File

@ -159,8 +159,6 @@ public class AuthHelper {
when(UNDISCOVERABLE_ACCOUNT.getDevices()).thenReturn(List.of(UNDISCOVERABLE_DEVICE)); when(UNDISCOVERABLE_ACCOUNT.getDevices()).thenReturn(List.of(UNDISCOVERABLE_DEVICE));
when(VALID_ACCOUNT_3.getDevices()).thenReturn(List.of(VALID_DEVICE_3_PRIMARY, VALID_DEVICE_3_LINKED)); when(VALID_ACCOUNT_3.getDevices()).thenReturn(List.of(VALID_DEVICE_3_PRIMARY, VALID_DEVICE_3_LINKED));
when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true);
when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER); when(VALID_ACCOUNT.getNumber()).thenReturn(VALID_NUMBER);
when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID); when(VALID_ACCOUNT.getUuid()).thenReturn(VALID_UUID);
when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI); when(VALID_ACCOUNT.getPhoneNumberIdentifier()).thenReturn(VALID_PNI);
@ -180,11 +178,6 @@ public class AuthHelper {
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_3); when(VALID_ACCOUNT_3.getIdentifier(IdentityType.ACI)).thenReturn(VALID_UUID_3);
when(VALID_ACCOUNT_3.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_3); when(VALID_ACCOUNT_3.getIdentifier(IdentityType.PNI)).thenReturn(VALID_PNI_3);
when(VALID_ACCOUNT.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT_TWO.isEnabled()).thenReturn(true);
when(UNDISCOVERABLE_ACCOUNT.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT_3.isEnabled()).thenReturn(true);
when(VALID_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(true); when(VALID_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(true);
when(VALID_ACCOUNT_TWO.isDiscoverableByPhoneNumber()).thenReturn(true); when(VALID_ACCOUNT_TWO.isDiscoverableByPhoneNumber()).thenReturn(true);
when(UNDISCOVERABLE_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(false); when(UNDISCOVERABLE_ACCOUNT.isDiscoverableByPhoneNumber()).thenReturn(false);
@ -284,7 +277,6 @@ public class AuthHelper {
when(account.getPrimaryDevice()).thenReturn(device); when(account.getPrimaryDevice()).thenReturn(device);
when(account.getNumber()).thenReturn(number); when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.isEnabled()).thenReturn(true);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account)); when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
} }

View File

@ -119,35 +119,44 @@ class DestinationDeviceValidatorTest {
return account; return account;
} }
static Stream<Arguments> validateCompleteDeviceListSource() { static Stream<Arguments> validateCompleteDeviceList() {
final byte id1 = 1; final byte id1 = 1;
final byte id2 = 2; final byte id2 = 2;
final byte id3 = 3; final byte id3 = 3;
return Stream.of( return Stream.of(
// Device IDs provided for all enabled devices
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id3), Set.of(id1, id3),
null, null,
null, null,
Collections.emptySet()), Collections.emptySet()),
// Device ID provided for disabled device
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id2, id3), Set.of(id1, id2, id3),
null, null,
Set.of(id2), null,
Collections.emptySet()), Collections.emptySet()),
// Device ID omitted for enabled device
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1), Set.of(id1),
Set.of(id3), Set.of(id3),
null, null,
Collections.emptySet()), Collections.emptySet()),
// Device ID included for disabled device, omitted for enabled device
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1, id2), Set.of(id1, id2),
Set.of(id3), Set.of(id3),
Set.of(id2), null,
Collections.emptySet()), Collections.emptySet()),
// Device ID omitted for enabled device, included for device in excluded list
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id1), Set.of(id1),
@ -155,13 +164,17 @@ class DestinationDeviceValidatorTest {
Set.of(id1), Set.of(id1),
Set.of(id1) Set.of(id1)
), ),
// Device ID omitted for enabled device, included for disabled device, omitted for excluded device
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id2), Set.of(id2),
Set.of(id3), Set.of(id3),
Set.of(id2), null,
Set.of(id1) Set.of(id1)
), ),
// Device ID included for enabled device, omitted for excluded device
arguments( arguments(
mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)),
Set.of(id3), Set.of(id3),
@ -173,8 +186,8 @@ class DestinationDeviceValidatorTest {
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("validateCompleteDeviceListSource") @MethodSource
void testValidateCompleteDeviceList( void validateCompleteDeviceList(
Account account, Account account,
Set<Byte> deviceIds, Set<Byte> deviceIds,
Collection<Byte> expectedMissingDeviceIds, Collection<Byte> expectedMissingDeviceIds,

View File

@ -10,5 +10,11 @@
"destinationDeviceId" : 3, "destinationDeviceId" : 3,
"content" : "Zm9vYmFyego", "content" : "Zm9vYmFyego",
"timestamp" : 1234 "timestamp" : 1234
}] },
{
"type" : 1,
"destinationDeviceId" : 4,
"content" : "Zm9vYmFyego",
"timestamp" : 1234
}]
} }

View File

@ -1,4 +1,4 @@
{ {
"missingDevices" : [2], "missingDevices" : [2],
"extraDevices" : [3] "extraDevices" : [4]
} }