Simplify `Device` entity

This commit is contained in:
Jon Chambers 2022-07-13 13:55:20 -04:00 committed by GitHub
parent e200548e35
commit 1dd7d33e23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 281 additions and 354 deletions

View File

@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.function.Predicate; import java.util.function.Predicate;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -46,7 +43,7 @@ public class Account {
private String username; private String username;
@JsonProperty @JsonProperty
private Set<Device> devices = new HashSet<>(); private List<Device> devices = new ArrayList<>();
@JsonProperty @JsonProperty
private String identityKey; private String identityKey;
@ -84,17 +81,6 @@ public class Account {
@JsonIgnore @JsonIgnore
private boolean canonicallyDiscoverable; private boolean canonicallyDiscoverable;
public Account() {}
@VisibleForTesting
public Account(String number, UUID uuid, final UUID phoneNumberIdentifier, Set<Device> devices, byte[] unidentifiedAccessKey) {
this.number = number;
this.uuid = uuid;
this.phoneNumberIdentifier = phoneNumberIdentifier;
this.devices = devices;
this.unidentifiedAccessKey = unidentifiedAccessKey;
}
public UUID getUuid() { public UUID getUuid() {
// this is the one method that may be called on a stale account // this is the one method that may be called on a stale account
return uuid; return uuid;
@ -150,17 +136,17 @@ public class Account {
public void addDevice(Device device) { public void addDevice(Device device) {
requireNotStale(); requireNotStale();
this.devices.remove(device); removeDevice(device.getId());
this.devices.add(device); this.devices.add(device);
} }
public void removeDevice(long deviceId) { public void removeDevice(long deviceId) {
requireNotStale(); requireNotStale();
this.devices.remove(new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "NA", 0, null)); this.devices.removeIf(device -> device.getId() == deviceId);
} }
public Set<Device> getDevices() { public List<Device> getDevices() {
requireNotStale(); requireNotStale();
return devices; return devices;
@ -175,13 +161,7 @@ public class Account {
public Optional<Device> getDevice(long deviceId) { public Optional<Device> getDevice(long deviceId) {
requireNotStale(); requireNotStale();
for (Device device : devices) { return devices.stream().filter(device -> device.getId() == deviceId).findFirst();
if (device.getId() == deviceId) {
return Optional.of(device);
}
}
return Optional.empty();
} }
public boolean isGroupsV2Supported() { public boolean isGroupsV2Supported() {

View File

@ -67,31 +67,6 @@ public class Device {
@JsonProperty @JsonProperty
private DeviceCapabilities capabilities; private DeviceCapabilities capabilities;
public Device() {}
public Device(long id, String name, String authToken, String salt,
String gcmId, String apnId,
String voipApnId, boolean fetchesMessages,
int registrationId, SignedPreKey signedPreKey,
long lastSeen, long created, String userAgent,
long uninstalledFeedback, DeviceCapabilities capabilities) {
this.id = id;
this.name = name;
this.authToken = authToken;
this.salt = salt;
this.gcmId = gcmId;
this.apnId = apnId;
this.voipApnId = voipApnId;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
this.lastSeen = lastSeen;
this.created = created;
this.userAgent = userAgent;
this.uninstalledFeedback = uninstalledFeedback;
this.capabilities = capabilities;
}
public String getApnId() { public String getApnId() {
return apnId; return apnId;
} }
@ -251,16 +226,6 @@ public class Device {
return groupsV2Supported; return groupsV2Supported;
} }
@Override
public boolean equals(Object other) {
return (other instanceof Device that) && this.id == that.id;
}
@Override
public int hashCode() {
return (int)this.id;
}
public static class DeviceCapabilities { public static class DeviceCapabilities {
@JsonProperty @JsonProperty
private boolean gv2; private boolean gv2;

View File

@ -12,7 +12,6 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
@ -41,8 +40,7 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
for (Account account : chunkAccounts) { for (Account account : chunkAccounts) {
boolean update = false; boolean update = false;
final Set<Device> devices = account.getDevices(); for (Device device : account.getDevices()) {
for (Device device : devices) {
if (deviceNeedsUpdate(device)) { if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) { if (deviceExpired(device)) {
if (device.isEnabled()) { if (device.isEnabled()) {

View File

@ -32,9 +32,9 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.Principal; import java.security.Principal;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -141,7 +141,7 @@ class AuthEnablementRefreshRequirementProviderTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
final Set<Device> devices = new HashSet<>(); final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices); when(account.getDevices()).thenReturn(devices);
LongStream.range(1, 5) LongStream.range(1, 5)

View File

@ -17,14 +17,13 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.dropwizard.auth.basic.BasicCredentials;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream; import java.util.stream.Stream;
import io.dropwizard.auth.basic.BasicCredentials;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
@ -35,12 +34,10 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
class BaseAccountAuthenticatorTest { class BaseAccountAuthenticatorTest {
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L; private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L; private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L; private final long oldTime = yesterday - 86_400_000L;
@ -59,19 +56,22 @@ class BaseAccountAuthenticatorTest {
clock = mock(Clock.class); clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock); baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), UUID.randomUUID(), // We use static UUIDs here because the UUID affects the "date last seen" offset
Set.of(new Device(1, null, null, null, acct1 = AccountsHelper.generateTestAccount("+14088675309", UUID.fromString("c139cb3e-f70c-4460-b221-815e8bdf778f"), UUID.randomUUID(), List.of(generateTestDevice(yesterday)), null);
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null); acct2 = AccountsHelper.generateTestAccount("+14088675310", UUID.fromString("30018a41-2764-4bc7-a935-775dfef84ad1"), UUID.randomUUID(), List.of(generateTestDevice(yesterday)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), UUID.randomUUID(), oldAccount = AccountsHelper.generateTestAccount("+14088675311", UUID.fromString("adfce52b-9299-4c25-9c51-412fb420c6a6"), UUID.randomUUID(), List.of(generateTestDevice(oldTime)), null);
Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), UUID.randomUUID(),
Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
AccountsHelper.setupMockUpdate(accountsManager); AccountsHelper.setupMockUpdate(accountsManager);
} }
private static Device generateTestDevice(final long lastSeen) {
final Device device = new Device();
device.setId(Device.MASTER_ID);
device.setLastSeen(lastSeen);
return device;
}
@Test @Test
void testUpdateLastSeenMiddleOfDay() { void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime)); when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));

View File

@ -5,6 +5,20 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.ws.rs.core.SecurityContext;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -13,22 +27,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import javax.annotation.Nullable;
import javax.ws.rs.core.SecurityContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
class PhoneNumberChangeRefreshRequirementProviderTest { class PhoneNumberChangeRefreshRequirementProviderTest {
private PhoneNumberChangeRefreshRequirementProvider provider; private PhoneNumberChangeRefreshRequirementProvider provider;
@ -50,7 +48,7 @@ class PhoneNumberChangeRefreshRequirementProviderTest {
when(account.getUuid()).thenReturn(ACCOUNT_UUID); when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(NUMBER); when(account.getNumber()).thenReturn(NUMBER);
when(account.getDevices()).thenReturn(Set.of(device)); when(account.getDevices()).thenReturn(List.of(device));
when(device.getId()).thenReturn(Device.MASTER_ID); when(device.getId()).thenReturn(Device.MASTER_ID);
request = mock(ContainerRequest.class); request = mock(ContainerRequest.class);

View File

@ -35,7 +35,6 @@ import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -80,6 +79,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -126,31 +126,19 @@ class MessageControllerTest {
@BeforeEach @BeforeEach
void setup() { void setup() {
Set<Device> singleDeviceList = new HashSet<>() {{ final List<Device> singleDeviceList = List.of(
add(new Device(1, null, "foo", "bar", generateTestDevice(1, 111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis())
"isgcm", null, null, false, 111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), );
System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, true, false,
false, false, false, false, false, false)));
}};
Set<Device> multiDeviceList = new HashSet<>() {{ final List<Device> multiDeviceList = List.of(
add(new Device(1, null, "foo", "bar", generateTestDevice(1, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis()),
"isgcm", null, null, false, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), generateTestDevice(2, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis()),
System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, false, false, generateTestDevice(3, 444, null, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
false, false, false, false, false, false))); );
add(new Device(2, null, "foo", "bar",
"isgcm", null, null, false, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(),
System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, false, false,
false, false, false, false, false, false)));
add(new Device(3, null, "foo", "bar",
"isgcm", null, null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31),
System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(false, false, false, false, false, false,
false, false, false, false, false, false)));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes()); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, UUID.randomUUID(), multiDeviceList, "1234".getBytes()); Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, UUID.randomUUID(), multiDeviceList, "1234".getBytes());
internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes()); internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, "1234".getBytes());
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
@ -160,6 +148,18 @@ class MessageControllerTest {
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
} }
private static Device generateTestDevice(final long id, final int registrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);
device.setSignedPreKey(signedPreKey);
device.setCreated(createdAt);
device.setLastSeen(lastSeen);
device.setGcmId("isgcm");
return device;
}
@AfterEach @AfterEach
void teardown() { void teardown() {
reset( reset(

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex; import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
@ -256,8 +255,6 @@ class AccountsManagerChangeNumberIntegrationTest {
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.update(existingAccount, a -> a.addDevice(new Device(Device.MASTER_ID, "test", "token", "salt", null, null, null, true, 1, null, 0, 0, null, 0, new DeviceCapabilities())));
accountsManager.changeNumber(account, secondNumber); accountsManager.changeNumber(account, secondNumber);
assertTrue(accountsManager.getByE164(originalNumber).isEmpty()); assertTrue(accountsManager.getByE164(originalNumber).isEmpty());

View File

@ -46,6 +46,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient; import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers; import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -183,15 +184,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
"testSignature-" + random.nextInt()); "testSignature-" + random.nextInt());
a.removeDevice(1); a.removeDevice(1);
a.addDevice(new Device(1, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), a.addDevice(DevicesHelper.createDevice(1));
"testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(),
random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(),
"testUserAgent-" + random.nextInt(), 0,
new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean())));
}); });
uuid = account.getUuid(); uuid = account.getUuid();

View File

@ -30,7 +30,6 @@ import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.time.Clock; import java.time.Clock;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -53,6 +52,7 @@ import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
class AccountsManagerTest { class AccountsManagerTest {
@ -231,7 +231,7 @@ class AccountsManagerTest {
void testGetAccountByNumberNotInCache() { void testGetAccountByNumberNotInCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null); when(commands.get(eq("AccountMap::+14152222222"))).thenReturn(null);
when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account)); when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account));
@ -255,7 +255,7 @@ class AccountsManagerTest {
void testGetAccountByUuidNotInCache() { void testGetAccountByUuidNotInCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
@ -280,7 +280,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("AccountMap::" + pni))).thenReturn(null); when(commands.get(eq("AccountMap::" + pni))).thenReturn(null);
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
@ -305,7 +305,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
String username = "test"; String username = "test";
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account.setUsername(username); account.setUsername(username);
when(commands.get(eq("AccountMap::" + username))).thenReturn(null); when(commands.get(eq("AccountMap::" + username))).thenReturn(null);
@ -331,7 +331,7 @@ class AccountsManagerTest {
void testGetAccountByNumberBrokenCache() { void testGetAccountByNumberBrokenCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!")); when(commands.get(eq("AccountMap::+14152222222"))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account)); when(accounts.getByE164(eq("+14152222222"))).thenReturn(Optional.of(account));
@ -355,7 +355,7 @@ class AccountsManagerTest {
void testGetAccountByUuidBrokenCache() { void testGetAccountByUuidBrokenCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
@ -380,7 +380,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO")); when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO"));
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
@ -405,7 +405,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
String username = "test"; String username = "test";
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account.setUsername(username); account.setUsername(username);
when(commands.get(eq("AccountMap::" + username))).thenThrow(new RedisException("OH NO")); when(commands.get(eq("AccountMap::" + username))).thenThrow(new RedisException("OH NO"));
@ -431,18 +431,18 @@ class AccountsManagerTest {
void testUpdate_optimisticLockingFailure() { void testUpdate_optimisticLockingFailure() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn( when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]))); Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class) doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER) .doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any()); .when(accounts).update(any());
when(accounts.getByAccountIdentifier(uuid)).thenReturn( when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, pni, new HashSet<>(), new byte[16]))); Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class) doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER) .doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any()); .when(accounts).update(any());
@ -460,7 +460,7 @@ class AccountsManagerTest {
@Test @Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() { void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
@ -477,10 +477,10 @@ class AccountsManagerTest {
@Test @Test
void testUpdateDevice() { void testUpdateDevice() {
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
Account account = new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
when(accounts.getByAccountIdentifier(uuid)).thenReturn( when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(new Account("+14152222222", uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]))); Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16])));
assertTrue(account.getDevices().isEmpty()); assertTrue(account.getDevices().isEmpty());
@ -596,12 +596,10 @@ class AccountsManagerTest {
@MethodSource @MethodSource
void testUpdateDirectoryQueue(final boolean visibleBeforeUpdate, final boolean visibleAfterUpdate, void testUpdateDirectoryQueue(final boolean visibleBeforeUpdate, final boolean visibleAfterUpdate,
final boolean expectRefresh) { final boolean expectRefresh) {
final Account account = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
// this sets up the appropriate result for Account#shouldBeVisibleInDirectory // this sets up the appropriate result for Account#shouldBeVisibleInDirectory
final Device device = new Device(Device.MASTER_ID, "device", "token", "salt", null, null, null, true, 1, final Device device = generateTestDevice(0);
new SignedPreKey(1, "key", "sig"), 0, 0,
"OWT", 0, new DeviceCapabilities());
account.addDevice(device); account.addDevice(device);
account.setDiscoverableByPhoneNumber(visibleBeforeUpdate); account.setDiscoverableByPhoneNumber(visibleBeforeUpdate);
@ -623,10 +621,8 @@ class AccountsManagerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) { void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) {
final Account account = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final Device device = new Device(Device.MASTER_ID, "device", "token", "salt", null, null, null, true, 1, final Device device = generateTestDevice(initialLastSeen);
new SignedPreKey(1, "key", "sig"), initialLastSeen, 0,
"OWT", 0, new DeviceCapabilities());
account.addDevice(device); account.addDevice(device);
accountsManager.updateDeviceLastSeen(account, device, updatedLastSeen); accountsManager.updateDeviceLastSeen(account, device, updatedLastSeen);
@ -654,7 +650,7 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
Account account = new Account(originalNumber, uuid, originalPni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber); account = accountsManager.changeNumber(account, targetNumber);
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
@ -670,7 +666,7 @@ class AccountsManagerTest {
void testChangePhoneNumberSameNumber() throws InterruptedException { void testChangePhoneNumberSameNumber() throws InterruptedException {
final String number = "+14152222222"; final String number = "+14152222222";
Account account = new Account(number, UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, number); account = accountsManager.changeNumber(account, number);
assertEquals(number, account.getNumber()); assertEquals(number, account.getNumber());
@ -691,10 +687,10 @@ class AccountsManagerTest {
final UUID originalPni = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final Account existingAccount = new Account(targetNumber, existingAccountUuid, targetPni, new HashSet<>(), new byte[16]); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
Account account = new Account(originalNumber, uuid, originalPni, new HashSet<>(), new byte[16]); Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, new ArrayList<>(), new byte[16]);
account = accountsManager.changeNumber(account, targetNumber); account = accountsManager.changeNumber(account, targetNumber);
assertEquals(targetNumber, account.getNumber()); assertEquals(targetNumber, account.getNumber());
@ -713,14 +709,14 @@ class AccountsManagerTest {
final String targetNumber = "+14153333333"; final String targetNumber = "+14153333333";
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final Account account = new Account(originalNumber, uuid, UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID()))); assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setNumber(targetNumber, UUID.randomUUID())));
} }
@Test @Test
void testSetUsername() throws UsernameNotAvailableException { void testSetUsername() throws UsernameNotAvailableException {
final Account account = new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test"; final String username = "test";
assertDoesNotThrow(() -> accountsManager.setUsername(account, username)); assertDoesNotThrow(() -> accountsManager.setUsername(account, username));
@ -729,7 +725,7 @@ class AccountsManagerTest {
@Test @Test
void testSetUsernameSameUsername() throws UsernameNotAvailableException { void testSetUsernameSameUsername() throws UsernameNotAvailableException {
final Account account = new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test"; final String username = "test";
account.setUsername(username); account.setUsername(username);
@ -739,7 +735,7 @@ class AccountsManagerTest {
@Test @Test
void testSetUsernameNotAvailable() throws UsernameNotAvailableException { void testSetUsernameNotAvailable() throws UsernameNotAvailableException {
final Account account = new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test"; final String username = "test";
doThrow(new UsernameNotAvailableException()).when(accounts).setUsername(account, username); doThrow(new UsernameNotAvailableException()).when(accounts).setUsername(account, username);
@ -755,7 +751,7 @@ class AccountsManagerTest {
final String username = "reserved"; final String username = "reserved";
when(reservedUsernames.isReserved(eq(username), any())).thenReturn(true); when(reservedUsernames.isReserved(eq(username), any())).thenReturn(true);
final Account account = new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, username)); assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, username));
assertTrue(account.getUsername().isEmpty()); assertTrue(account.getUsername().isEmpty());
@ -763,8 +759,18 @@ class AccountsManagerTest {
@Test @Test
void testSetUsernameViaUpdate() { void testSetUsernameViaUpdate() {
final Account account = new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>(), new byte[16]); final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsername("test"))); assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsername("test")));
} }
private static Device generateTestDevice(final long lastSeen) {
final Device device = new Device();
device.setId(Device.MASTER_ID);
device.setFetchesMessages(true);
device.setSignedPreKey(new SignedPreKey(1, "key", "sig"));
device.setLastSeen(lastSeen);
return device;
}
} }

View File

@ -21,12 +21,10 @@ import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Random; import java.util.Random;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -41,7 +39,8 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
@ -153,7 +152,7 @@ class AccountsTest {
@Test @Test
void testStore() { void testStore() {
Device device = generateDevice(1); Device device = generateDevice(1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account); boolean freshUser = accounts.create(account);
@ -173,11 +172,8 @@ class AccountsTest {
@Test @Test
void testStoreMulti() { void testStoreMulti() {
Set<Device> devices = new HashSet<>(); final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
devices.add(generateDevice(1)); final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
devices.add(generateDevice(2));
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
accounts.create(account); accounts.create(account);
@ -189,17 +185,13 @@ class AccountsTest {
@Test @Test
void testRetrieve() { void testRetrieve() {
Set<Device> devicesFirst = new HashSet<>(); final List<Device> devicesFirst = List.of(generateDevice(1), generateDevice(2));
devicesFirst.add(generateDevice(1));
devicesFirst.add(generateDevice(2));
UUID uuidFirst = UUID.randomUUID(); UUID uuidFirst = UUID.randomUUID();
UUID pniFirst = UUID.randomUUID(); UUID pniFirst = UUID.randomUUID();
Account accountFirst = generateAccount("+14151112222", uuidFirst, pniFirst, devicesFirst); Account accountFirst = generateAccount("+14151112222", uuidFirst, pniFirst, devicesFirst);
Set<Device> devicesSecond = new HashSet<>(); final List<Device> devicesSecond = List.of(generateDevice(1), generateDevice(2));
devicesSecond.add(generateDevice(1));
devicesSecond.add(generateDevice(2));
UUID uuidSecond = UUID.randomUUID(); UUID uuidSecond = UUID.randomUUID();
UUID pniSecond = UUID.randomUUID(); UUID pniSecond = UUID.randomUUID();
@ -238,12 +230,8 @@ class AccountsTest {
@Test @Test
void testRetrieveNoPni() throws JsonProcessingException { void testRetrieveNoPni() throws JsonProcessingException {
final Set<Device> devices = new HashSet<>(); final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
devices.add(generateDevice(1));
devices.add(generateDevice(2));
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final Account account = generateAccount("+14151112222", uuid, null, devices); final Account account = generateAccount("+14151112222", uuid, null, devices);
// Accounts#create enforces that newly-created accounts have a PNI, so we need to make a bit of an end-run around it // Accounts#create enforces that newly-created accounts have a PNI, so we need to make a bit of an end-run around it
@ -303,7 +291,7 @@ class AccountsTest {
Device device = generateDevice(1); Device device = generateDevice(1);
UUID firstUuid = UUID.randomUUID(); UUID firstUuid = UUID.randomUUID();
UUID firstPni = UUID.randomUUID(); UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, Collections.singleton(device)); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
accounts.create(account); accounts.create(account);
@ -317,7 +305,7 @@ class AccountsTest {
UUID secondUuid = UUID.randomUUID(); UUID secondUuid = UUID.randomUUID();
device = generateDevice(1); device = generateDevice(1);
account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), Collections.singleton(device)); account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device));
final boolean freshUser = accounts.create(account); final boolean freshUser = accounts.create(account);
assertThat(freshUser).isFalse(); assertThat(freshUser).isFalse();
@ -327,7 +315,7 @@ class AccountsTest {
assertPhoneNumberIdentifierConstraintExists(firstPni, firstUuid); assertPhoneNumberIdentifierConstraintExists(firstPni, firstUuid);
device = generateDevice(1); device = generateDevice(1);
Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), Collections.singleton(device)); Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.create(invalidAccount)); assertThatThrownBy(() -> accounts.create(invalidAccount));
} }
@ -335,7 +323,7 @@ class AccountsTest {
@Test @Test
void testUpdate() { void testUpdate() {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); accounts.create(account);
@ -360,7 +348,7 @@ class AccountsTest {
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), account, true);
device = generateDevice(1); device = generateDevice(1);
Account unknownAccount = generateAccount("+14151113333", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); Account unknownAccount = generateAccount("+14151113333", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class); assertThatThrownBy(() -> accounts.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
@ -454,10 +442,10 @@ class AccountsTest {
void testDelete() { void testDelete() {
final Device deletedDevice = generateDevice(1); final Device deletedDevice = generateDevice(1);
final Account deletedAccount = generateAccount("+14151112222", UUID.randomUUID(), final Account deletedAccount = generateAccount("+14151112222", UUID.randomUUID(),
UUID.randomUUID(), Collections.singleton(deletedDevice)); UUID.randomUUID(), List.of(deletedDevice));
final Device retainedDevice = generateDevice(1); final Device retainedDevice = generateDevice(1);
final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(), final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(),
UUID.randomUUID(), Collections.singleton(retainedDevice)); UUID.randomUUID(), List.of(retainedDevice));
accounts.create(deletedAccount); accounts.create(deletedAccount);
accounts.create(retainedAccount); accounts.create(retainedAccount);
@ -482,7 +470,7 @@ class AccountsTest {
{ {
final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(), final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(),
UUID.randomUUID(), Collections.singleton(generateDevice(1))); UUID.randomUUID(), List.of(generateDevice(1)));
final boolean freshUser = accounts.create(recreatedAccount); final boolean freshUser = accounts.create(recreatedAccount);
@ -499,7 +487,7 @@ class AccountsTest {
@Test @Test
void testMissing() { void testMissing() {
Device device = generateDevice (1 ); Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); accounts.create(account);
@ -567,7 +555,7 @@ class AccountsTest {
@Test @Test
void testCanonicallyDiscoverableSet() { void testCanonicallyDiscoverableSet() {
Device device = generateDevice(1); Device device = generateDevice(1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
account.setDiscoverableByPhoneNumber(false); account.setDiscoverableByPhoneNumber(false);
accounts.create(account); accounts.create(account);
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), account, false); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), account, false);
@ -588,7 +576,7 @@ class AccountsTest {
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final Device device = generateDevice(1); final Device device = generateDevice(1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, Collections.singleton(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account); accounts.create(account);
@ -634,10 +622,10 @@ class AccountsTest {
final UUID targetPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID();
final Device existingDevice = generateDevice(1); final Device existingDevice = generateDevice(1);
final Account existingAccount = generateAccount(targetNumber, UUID.randomUUID(), targetPni, Collections.singleton(existingDevice)); final Account existingAccount = generateAccount(targetNumber, UUID.randomUUID(), targetPni, List.of(existingDevice));
final Device device = generateDevice(1); final Device device = generateDevice(1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, Collections.singleton(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account); accounts.create(account);
accounts.create(existingAccount); accounts.create(existingAccount);
@ -656,7 +644,7 @@ class AccountsTest {
final String targetNumber = "+14151113333"; final String targetNumber = "+14151113333";
final Device device = generateDevice(1); final Device device = generateDevice(1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); accounts.create(account);
@ -952,31 +940,20 @@ class AccountsTest {
} }
private Device generateDevice(long id) { private Device generateDevice(long id) {
Random random = new Random(System.currentTimeMillis()); return DevicesHelper.createDevice(id);
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(),
"testSignature-" + random.nextInt());
return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(),
"testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(),
random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(),
"testUserAgent-" + random.nextInt(), 0,
new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean()));
} }
private Account generateAccount(String number, UUID uuid, final UUID pni) { private Account generateAccount(String number, UUID uuid, final UUID pni) {
Device device = generateDevice(1); Device device = generateDevice(1);
return generateAccount(number, uuid, pni, Collections.singleton(device)); return generateAccount(number, uuid, pni, List.of(device));
} }
private Account generateAccount(String number, UUID uuid, final UUID pni, Set<Device> devices) { private Account generateAccount(String number, UUID uuid, final UUID pni, List<Device> devices) {
byte[] unidentifiedAccessKey = new byte[16]; byte[] unidentifiedAccessKey = new byte[16];
Random random = new Random(System.currentTimeMillis()); Random random = new Random(System.currentTimeMillis());
Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255)); Arrays.fill(unidentifiedAccessKey, (byte)random.nextInt(255));
return new Account(number, uuid, pni, devices, unidentifiedAccessKey); return AccountsHelper.generateTestAccount(number, uuid, pni, devices, unidentifiedAccessKey);
} }
private void assertPhoneNumberConstraintExists(final String number, final UUID uuid) { private void assertPhoneNumberConstraintExists(final String number, final UUID uuid) {

View File

@ -4,6 +4,18 @@
*/ */
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -13,20 +25,6 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ChangeNumberManagerTest { public class ChangeNumberManagerTest {
private static AccountsManager accountsManager = mock(AccountsManager.class); private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessageSender messageSender = mock(MessageSender.class); private static MessageSender messageSender = mock(MessageSender.class);
@ -40,7 +38,7 @@ public class ChangeNumberManagerTest {
final String number = invocation.getArgument(1, String.class); final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
final Set<Device> devices = account.getDevices(); final List<Device> devices = account.getDevices();
final Account updatedAccount = mock(Account.class); final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getUuid()).thenReturn(uuid);

View File

@ -21,9 +21,17 @@ class DeviceTest {
@MethodSource @MethodSource
void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId, void testIsEnabled(final boolean master, final boolean fetchesMessages, final String apnId, final String gcmId,
final SignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) { final SignedPreKey signedPreKey, final Duration timeSinceLastSeen, final boolean expectEnabled) {
final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis(); final long lastSeen = System.currentTimeMillis() - timeSinceLastSeen.toMillis();
final Device device = new Device(master ? 1 : 2, "test", "auth-token", "salt", gcmId, apnId, null, fetchesMessages,
1, signedPreKey, lastSeen, lastSeen, "user-agent", 0, null); final Device device = new Device();
device.setId(master ? Device.MASTER_ID : Device.MASTER_ID + 1);
device.setFetchesMessages(fetchesMessages);
device.setApnId(apnId);
device.setGcmId(gcmId);
device.setSignedPreKey(signedPreKey);
device.setCreated(lastSeen);
device.setLastSeen(lastSeen);
assertEquals(expectEnabled, device.isEnabled()); assertEquals(expectEnabled, device.isEnabled());
} }
@ -73,8 +81,11 @@ class DeviceTest {
final Device.DeviceCapabilities capabilities = new Device.DeviceCapabilities(gv2Capability, gv2_2Capability, final Device.DeviceCapabilities capabilities = new Device.DeviceCapabilities(gv2Capability, gv2_2Capability,
gv2_3Capability, false, false, false, gv2_3Capability, false, false, false,
false, false, false, false, false, false); false, false, false, false, false, false);
final Device device = new Device(master ? 1 : 2, "test", "auth-token", "salt",
null, apnId, null, false, 1, null, 0, 0, "user-agent", 0, capabilities); final Device device = new Device();
device.setId(master ? Device.MASTER_ID : Device.MASTER_ID + 1);
device.setApnId(apnId);
device.setCapabilities(capabilities);
assertEquals(expectGv2Supported, device.isGroupsV2Supported()); assertEquals(expectGv2Supported, device.isGroupsV2Supported());
} }

View File

@ -31,7 +31,6 @@ import java.io.IOException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Duration; import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
@ -259,7 +258,7 @@ class AccountControllerTest {
final String number = invocation.getArgument(1, String.class); final String number = invocation.getArgument(1, String.class);
final UUID uuid = account.getUuid(); final UUID uuid = account.getUuid();
final Set<Device> devices = account.getDevices(); final List<Device> devices = account.getDevices();
final Account updatedAccount = mock(Account.class); final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getUuid()).thenReturn(uuid);
@ -1557,7 +1556,7 @@ class AccountControllerTest {
Device device3 = mock(Device.class); Device device3 = mock(Device.class);
when(device3.getId()).thenReturn(3L); when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true); when(device3.isEnabled()).thenReturn(true);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
new StoredVerificationCode(code, System.currentTimeMillis(), "push", null))); new StoredVerificationCode(code, System.currentTimeMillis(), "push", null)));
@ -1591,7 +1590,7 @@ class AccountControllerTest {
when(device3.getId()).thenReturn(3L); when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true); when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3); when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(
@ -1633,7 +1632,7 @@ class AccountControllerTest {
when(device3.getId()).thenReturn(3L); when(device3.getId()).thenReturn(3L);
when(device3.isEnabled()).thenReturn(true); when(device3.isEnabled()).thenReturn(true);
when(device3.getRegistrationId()).thenReturn(3); when(device3.getRegistrationId()).thenReturn(3);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(AuthHelper.VALID_DEVICE, device2, device3)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(AuthHelper.VALID_DEVICE, device2, device3));
when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2)); when(AuthHelper.VALID_ACCOUNT.getDevice(2L)).thenReturn(Optional.of(device2));
when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3)); when(AuthHelper.VALID_ACCOUNT.getDevice(3L)).thenReturn(Optional.of(device3));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of( when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(

View File

@ -21,9 +21,9 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
@ -162,7 +162,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID); when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest() VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
@ -194,7 +194,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.MASTER_ID); when(existingDevice.getId()).thenReturn(Device.MASTER_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest() VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")

View File

@ -25,11 +25,9 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration; import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
@ -115,12 +113,7 @@ class KeysControllerTest {
final Device sampleDevice3 = mock(Device.class); final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class); final Device sampleDevice4 = mock(Device.class);
Set<Device> allDevices = new HashSet<>() {{ final List<Device> allDevices = List.of(sampleDevice, sampleDevice2, sampleDevice3, sampleDevice4);
add(sampleDevice);
add(sampleDevice2);
add(sampleDevice3);
add(sampleDevice4);
}};
AccountsHelper.setupMockUpdate(accounts); AccountsHelper.setupMockUpdate(accounts);
@ -351,7 +344,7 @@ class KeysControllerTest {
@Test @Test
void testNoDevices() { void testNoDevices() {
when(existsAccount.getDevices()).thenReturn(Collections.emptySet()); when(existsAccount.getDevices()).thenReturn(Collections.emptyList());
Response result = resources.getJerseyTest() Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .target(String.format("/v2/keys/%s/*", EXISTS_UUID))

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.storage;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -19,8 +20,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.List;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -29,6 +29,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge; import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
class AccountTest { class AccountTest {
@ -164,14 +165,17 @@ class AccountTest {
new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false)); new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false));
when(pniIncapableExpiredDevice.isEnabled()).thenReturn(false); when(pniIncapableExpiredDevice.isEnabled()).thenReturn(false);
when(storiesCapableDevice.getId()).thenReturn(1L);
when(storiesCapableDevice.getCapabilities()).thenReturn( when(storiesCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, true, false)); new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, true, false));
when(storiesCapableDevice.isEnabled()).thenReturn(true); when(storiesCapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(2L);
when(storiesIncapableDevice.getCapabilities()).thenReturn( when(storiesIncapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false)); new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false));
when(storiesIncapableDevice.isEnabled()).thenReturn(true); when(storiesIncapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(3L);
when(storiesIncapableExpiredDevice.getCapabilities()).thenReturn( when(storiesIncapableExpiredDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false)); new DeviceCapabilities(true, true, true, true, true, true, true, false, false, false, false, false));
when(storiesIncapableExpiredDevice.isEnabled()).thenReturn(false); when(storiesIncapableExpiredDevice.isEnabled()).thenReturn(false);
@ -204,33 +208,19 @@ class AccountTest {
when(disabledMasterDevice.getId()).thenReturn(1L); when(disabledMasterDevice.getId()).thenReturn(1L);
when(disabledLinkedDevice.getId()).thenReturn(2L); when(disabledLinkedDevice.getId()).thenReturn(2L);
assertTrue( new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), Set.of(enabledMasterDevice), new byte[0]).isEnabled()); assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledMasterDevice)).isEnabled());
assertTrue( new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledMasterDevice, enabledLinkedDevice)).isEnabled());
Set.of(enabledMasterDevice, enabledLinkedDevice), new byte[0]).isEnabled()); assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledMasterDevice, disabledLinkedDevice)).isEnabled());
assertTrue( new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledMasterDevice)).isEnabled());
Set.of(enabledMasterDevice, disabledLinkedDevice), new byte[0]).isEnabled()); assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledMasterDevice, enabledLinkedDevice)).isEnabled());
assertFalse(new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), Set.of(disabledMasterDevice), new byte[0]).isEnabled()); assertFalse(AccountsHelper.generateTestAccount("+14151234567", List.of(disabledMasterDevice, disabledLinkedDevice)).isEnabled());
assertFalse(new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(),
Set.of(disabledMasterDevice, enabledLinkedDevice), new byte[0]).isEnabled());
assertFalse(new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(),
Set.of(disabledMasterDevice, disabledLinkedDevice), new byte[0]).isEnabled());
} }
@Test @Test
void testCapabilities() { void testCapabilities() {
Account uuidCapable = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<Device>() {{ final Account uuidCapable = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(gv2CapableDevice), "1234".getBytes());
add(gv2CapableDevice); final Account uuidIncapable = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(gv2CapableDevice, gv2IncapableDevice), "1234".getBytes());
}}, "1234".getBytes()); final Account uuidCapableWithExpiredIncapable = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(gv2CapableDevice, gv2IncapableExpiredDevice), "1234".getBytes());
Account uuidIncapable = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<Device>() {{
add(gv2CapableDevice);
add(gv2IncapableDevice);
}}, "1234".getBytes());
Account uuidCapableWithExpiredIncapable = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<Device>() {{
add(gv2CapableDevice);
add(gv2IncapableExpiredDevice);
}}, "1234".getBytes());
assertTrue(uuidCapable.isGroupsV2Supported()); assertTrue(uuidCapable.isGroupsV2Supported());
assertFalse(uuidIncapable.isGroupsV2Supported()); assertFalse(uuidIncapable.isGroupsV2Supported());
@ -263,23 +253,20 @@ class AccountTest {
{ {
final Account transferableMasterAccount = final Account transferableMasterAccount =
new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(transferCapableMasterDevice), "1234".getBytes()); AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(transferCapableMasterDevice), "1234".getBytes());
assertTrue(transferableMasterAccount.isTransferSupported()); assertTrue(transferableMasterAccount.isTransferSupported());
} }
{ {
final Account nonTransferableMasterAccount = final Account nonTransferableMasterAccount =
new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(nonTransferCapableMasterDevice), "1234".getBytes()); AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(nonTransferCapableMasterDevice), "1234".getBytes());
assertFalse(nonTransferableMasterAccount.isTransferSupported()); assertFalse(nonTransferableMasterAccount.isTransferSupported());
} }
{ {
final Account transferableLinkedAccount = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), new HashSet<>() {{ final Account transferableLinkedAccount = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(nonTransferCapableMasterDevice, transferCapableLinkedDevice), "1234".getBytes());
add(nonTransferCapableMasterDevice);
add(transferCapableLinkedDevice);
}}, "1234".getBytes());
assertFalse(transferableLinkedAccount.isTransferSupported()); assertFalse(transferableLinkedAccount.isTransferSupported());
} }
@ -287,7 +274,7 @@ class AccountTest {
@Test @Test
void testDiscoverableByPhoneNumber() { void testDiscoverableByPhoneNumber() {
final Account account = new Account("+14152222222", UUID.randomUUID(), UUID.randomUUID(), Collections.singleton(recentMasterDevice), final Account account = AccountsHelper.generateTestAccount("+14152222222", UUID.randomUUID(), UUID.randomUUID(), List.of(recentMasterDevice),
"1234".getBytes()); "1234".getBytes());
assertTrue(account.isDiscoverableByPhoneNumber(), assertTrue(account.isDiscoverableByPhoneNumber(),
@ -302,111 +289,111 @@ class AccountTest {
@Test @Test
void isGroupsV2Supported() { void isGroupsV2Supported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), Set.of(gv2CapableDevice), assertTrue(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), List.of(gv2CapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertTrue(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(gv2CapableDevice, gv2IncapableExpiredDevice), List.of(gv2CapableDevice, gv2IncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
assertFalse(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertFalse(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(gv2CapableDevice, gv2IncapableDevice), List.of(gv2CapableDevice, gv2IncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported()); "1234".getBytes(StandardCharsets.UTF_8)).isGroupsV2Supported());
} }
@Test @Test
void isGv1MigrationSupported() { void isGv1MigrationSupported() {
assertTrue(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), Set.of(gv1MigrationCapableDevice), assertTrue(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), List.of(gv1MigrationCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertFalse( assertFalse(
new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice), List.of(gv1MigrationCapableDevice, gv1MigrationIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported()); "1234".getBytes(StandardCharsets.UTF_8)).isGv1MigrationSupported());
assertTrue(new Account("+18005551234", UUID.randomUUID(), assertTrue(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8)) UUID.randomUUID(), List.of(gv1MigrationCapableDevice, gv1MigrationIncapableExpiredDevice), "1234".getBytes(StandardCharsets.UTF_8))
.isGv1MigrationSupported()); .isGv1MigrationSupported());
} }
@Test @Test
void isSenderKeySupported() { void isSenderKeySupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), Set.of(senderKeyCapableDevice), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), List.of(senderKeyCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(senderKeyCapableDevice, senderKeyIncapableDevice), List.of(senderKeyCapableDevice, senderKeyIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isFalse(); "1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(senderKeyCapableDevice, senderKeyIncapableExpiredDevice), UUID.randomUUID(), List.of(senderKeyCapableDevice, senderKeyIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isSenderKeySupported()).isTrue();
} }
@Test @Test
void isAnnouncementGroupSupported() { void isAnnouncementGroupSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(announcementGroupCapableDevice), UUID.randomUUID(), List.of(announcementGroupCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(announcementGroupCapableDevice, announcementGroupIncapableDevice), UUID.randomUUID(), List.of(announcementGroupCapableDevice, announcementGroupIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isFalse(); "1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(announcementGroupCapableDevice, announcementGroupIncapableExpiredDevice), UUID.randomUUID(), List.of(announcementGroupCapableDevice, announcementGroupIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isAnnouncementGroupSupported()).isTrue();
} }
@Test @Test
void isChangeNumberSupported() { void isChangeNumberSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(changeNumberCapableDevice), UUID.randomUUID(), List.of(changeNumberCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(changeNumberCapableDevice, changeNumberIncapableDevice), UUID.randomUUID(), List.of(changeNumberCapableDevice, changeNumberIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isFalse(); "1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(changeNumberCapableDevice, changeNumberIncapableExpiredDevice), UUID.randomUUID(), List.of(changeNumberCapableDevice, changeNumberIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isChangeNumberSupported()).isTrue();
} }
@Test @Test
void isPniSupported() { void isPniSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(pniCapableDevice), UUID.randomUUID(), List.of(pniCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(pniCapableDevice, pniIncapableDevice), UUID.randomUUID(), List.of(pniCapableDevice, pniIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isFalse(); "1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(pniCapableDevice, pniIncapableExpiredDevice), UUID.randomUUID(), List.of(pniCapableDevice, pniIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isPniSupported()).isTrue();
} }
@Test @Test
void isStoriesSupported() { void isStoriesSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(storiesCapableDevice), UUID.randomUUID(), List.of(storiesCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(storiesCapableDevice, storiesIncapableDevice), UUID.randomUUID(), List.of(storiesCapableDevice, storiesIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue();
// TODO stories capability // TODO stories capability
// "1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isFalse(); // "1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(),
UUID.randomUUID(), Set.of(storiesCapableDevice, storiesIncapableExpiredDevice), UUID.randomUUID(), List.of(storiesCapableDevice, storiesIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isStoriesSupported()).isTrue();
} }
@Test @Test
void isGiftBadgesSupported() { void isGiftBadgesSupported() {
assertThat(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(giftBadgesCapableDevice), List.of(giftBadgesCapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isTrue();
assertThat(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(giftBadgesCapableDevice, giftBadgesIncapableDevice), List.of(giftBadgesCapableDevice, giftBadgesIncapableDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isFalse(); "1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isFalse();
assertThat(new Account("+18005551234", UUID.randomUUID(), UUID.randomUUID(), assertThat(AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(),
Set.of(giftBadgesCapableDevice, giftBadgesIncapableExpiredDevice), List.of(giftBadgesCapableDevice, giftBadgesIncapableExpiredDevice),
"1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isTrue(); "1234".getBytes(StandardCharsets.UTF_8)).isGiftBadgesSupported()).isTrue();
} }
@Test @Test
void stale() { void stale() {
final Account account = new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), Collections.emptySet(), final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), Collections.emptyList(),
new byte[0]); new byte[0]);
assertDoesNotThrow(account::getNumber); assertDoesNotThrow(account::getNumber);
@ -420,10 +407,9 @@ class AccountTest {
@Test @Test
void getNextDeviceId() { void getNextDeviceId() {
final Set<Device> devices = new HashSet<>(); final List<Device> devices = List.of(createDevice(Device.MASTER_ID));
devices.add(createDevice(Device.MASTER_ID));
final Account account = new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), devices, new byte[0]); final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), devices, new byte[0]);
assertThat(account.getNextDeviceId()).isEqualTo(2L); assertThat(account.getNextDeviceId()).isEqualTo(2L);
@ -442,9 +428,22 @@ class AccountTest {
assertThat(account.getNextDeviceId()).isEqualTo(2L); assertThat(account.getNextDeviceId()).isEqualTo(2L);
} }
@Test
void replaceDevice() {
final Device firstDevice = createDevice(Device.MASTER_ID);
final Device secondDevice = createDevice(Device.MASTER_ID);
final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(firstDevice), new byte[0]);
assertEquals(List.of(firstDevice), account.getDevices());
account.addDevice(secondDevice);
assertEquals(List.of(secondDevice), account.getDevices());
}
@Test @Test
void addAndRemoveBadges() { void addAndRemoveBadges() {
final Account account = new Account("+14151234567", UUID.randomUUID(), UUID.randomUUID(), Set.of(createDevice(Device.MASTER_ID)), new byte[0]); final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), List.of(createDevice(Device.MASTER_ID)), new byte[0]);
final Clock clock = mock(Clock.class); final Clock clock = mock(Clock.class);
when(clock.instant()).thenReturn(Instant.ofEpochSecond(40)); when(clock.instant()).thenReturn(Instant.ofEpochSecond(40));

View File

@ -78,11 +78,11 @@ class PushFeedbackProcessorTest {
when(stillActiveDevice.getLastSeen()).thenReturn(Util.todayInMillis()); when(stillActiveDevice.getLastSeen()).thenReturn(Util.todayInMillis());
when(stillActiveDevice.isEnabled()).thenReturn(true); when(stillActiveDevice.isEnabled()).thenReturn(true);
when(uninstalledAccount.getDevices()).thenReturn(Set.of(uninstalledDevice)); when(uninstalledAccount.getDevices()).thenReturn(List.of(uninstalledDevice));
when(mixedAccount.getDevices()).thenReturn(Set.of(installedDevice, uninstalledDeviceTwo)); when(mixedAccount.getDevices()).thenReturn(List.of(installedDevice, uninstalledDeviceTwo));
when(freshAccount.getDevices()).thenReturn(Set.of(recentUninstalledDevice)); when(freshAccount.getDevices()).thenReturn(List.of(recentUninstalledDevice));
when(cleanAccount.getDevices()).thenReturn(Set.of(installedDeviceTwo)); when(cleanAccount.getDevices()).thenReturn(List.of(installedDeviceTwo));
when(stillActiveAccount.getDevices()).thenReturn(Set.of(stillActiveDevice)); when(stillActiveAccount.getDevices()).thenReturn(List.of(stillActiveDevice));
when(mixedAccount.getUuid()).thenReturn(UUID.randomUUID()); when(mixedAccount.getUuid()).thenReturn(UUID.randomUUID());
when(freshAccount.getUuid()).thenReturn(UUID.randomUUID()); when(freshAccount.getUuid()).thenReturn(UUID.randomUUID());

View File

@ -14,6 +14,7 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException; import java.io.IOException;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -26,6 +27,20 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper { public class AccountsHelper {
public static Account generateTestAccount(String number, List<Device> devices) {
return generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, null);
}
public static Account generateTestAccount(String number, UUID uuid, final UUID phoneNumberIdentifier, List<Device> devices, byte[] unidentifiedAccessKey) {
final Account account = new Account();
account.setNumber(number, phoneNumberIdentifier);
account.setUuid(uuid);
devices.forEach(account::addDevice);
account.setUnidentifiedAccessKey(unidentifiedAccessKey);
return account;
}
public static void setupMockUpdate(final AccountsManager mockAccountsManager) { public static void setupMockUpdate(final AccountsManager mockAccountsManager) {
setupMockUpdate(mockAccountsManager, true); setupMockUpdate(mockAccountsManager, true);
} }

View File

@ -15,8 +15,14 @@ public class DevicesHelper {
private static final Random RANDOM = new Random(); private static final Random RANDOM = new Random();
public static Device createDevice(final long deviceId) { public static Device createDevice(final long deviceId) {
final Device device = new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "OWT", 0, return createDevice(deviceId, 0);
null); }
public static Device createDevice(final long deviceId, final long lastSeen) {
final Device device = new Device();
device.setId(deviceId);
device.setLastSeen(lastSeen);
device.setUserAgent("OWT");
setEnabled(device, true); setEnabled(device, true);

View File

@ -11,8 +11,10 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -120,7 +122,7 @@ class MessageValidationTest {
if (deviceIdAndEnabled.length % 2 != 0) { if (deviceIdAndEnabled.length % 2 != 0) {
throw new IllegalArgumentException("invalid number of arguments specified; must be even"); throw new IllegalArgumentException("invalid number of arguments specified; must be even");
} }
final Set<Device> devices = new HashSet<>(deviceIdAndEnabled.length / 2); final List<Device> devices = new ArrayList<>(deviceIdAndEnabled.length / 2);
for (int i = 0; i < deviceIdAndEnabled.length; i+=2) { for (int i = 0; i < deviceIdAndEnabled.length; i+=2) {
if (!(deviceIdAndEnabled[i] instanceof Long)) { if (!(deviceIdAndEnabled[i] instanceof Long)) {
throw new IllegalArgumentException("device id is not instance of long at index " + i); throw new IllegalArgumentException("device id is not instance of long at index " + i);

View File

@ -30,11 +30,9 @@ import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -166,9 +164,7 @@ class WebSocketConnectionTest {
final Device sender1device = mock(Device.class); final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<>() {{ List<Device> sender1devices = List.of(sender1device);
add(sender1device);
}};
Account sender1 = mock(Account.class); Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices); when(sender1.getDevices()).thenReturn(sender1devices);
@ -323,9 +319,7 @@ class WebSocketConnectionTest {
final Device sender1device = mock(Device.class); final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<Device>() {{ List<Device> sender1devices = List.of(sender1device);
add(sender1device);
}};
Account sender1 = mock(Account.class); Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices); when(sender1.getDevices()).thenReturn(sender1devices);
@ -705,9 +699,7 @@ class WebSocketConnectionTest {
final Device sender1device = mock(Device.class); final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<>() {{ List<Device> sender1devices = List.of(sender1device);
add(sender1device);
}};
Account sender1 = mock(Account.class); Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices); when(sender1.getDevices()).thenReturn(sender1devices);
@ -780,9 +772,7 @@ class WebSocketConnectionTest {
final Device sender1device = mock(Device.class); final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<>() {{ List<Device> sender1devices = List.of(sender1device);
add(sender1device);
}};
Account sender1 = mock(Account.class); Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices); when(sender1.getDevices()).thenReturn(sender1devices);