From 6a428b4da92db882d0dd8d6f4d38c71e71dc6adb Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Tue, 24 Oct 2023 18:58:13 -0500 Subject: [PATCH] Convert Device.id from `long` to `byte` --- .../org/signal/integration/Operations.java | 2 +- .../org/signal/integration/TestDevice.java | 8 +- .../java/org/signal/integration/TestUser.java | 4 +- ...hEnablementRefreshRequirementProvider.java | 12 +- .../auth/BaseAccountAuthenticator.java | 10 +- .../auth/BasicAuthorizationHeader.java | 8 +- .../textsecuregcm/auth/OptionalAccess.java | 2 +- ...umberChangeRefreshRequirementProvider.java | 2 +- .../RegistrationLockVerificationManager.java | 2 +- .../WebsocketRefreshRequirementProvider.java | 2 +- .../auth/grpc/AuthenticatedDevice.java | 2 +- .../auth/grpc/AuthenticationUtil.java | 4 +- .../controllers/AccountController.java | 2 +- .../controllers/DeviceController.java | 4 +- .../controllers/KeysController.java | 2 +- .../controllers/MessageController.java | 10 +- .../MismatchedDevicesException.java | 10 +- .../controllers/ProvisioningController.java | 2 +- .../controllers/StaleDevicesException.java | 7 +- .../entities/AccountDataReportResponse.java | 2 +- .../entities/ChangeNumberRequest.java | 6 +- .../entities/DeviceResponse.java | 6 +- .../entities/IncomingMessage.java | 4 +- .../entities/MismatchedDevices.java | 4 +- .../entities/MultiRecipientMessage.java | 2 +- ...eNumberIdentityKeyDistributionRequest.java | 12 +- .../entities/PreKeyResponse.java | 2 +- .../entities/PreKeyResponseItem.java | 7 +- .../textsecuregcm/entities/StaleDevices.java | 2 +- .../entities/UnregisteredEvent.java | 51 --- .../entities/UnregisteredEventList.java | 22 -- .../textsecuregcm/grpc/DeviceIdUtil.java | 18 + .../grpc/DevicesGrpcService.java | 9 +- .../grpc/KeysAnonymousGrpcService.java | 4 +- .../textsecuregcm/grpc/KeysGrpcHelper.java | 7 +- .../textsecuregcm/grpc/KeysGrpcService.java | 6 +- .../MultiRecipientMessageProvider.java | 9 +- .../push/ApnPushNotificationScheduler.java | 6 +- .../push/ClientPresenceManager.java | 19 +- .../push/PushLatencyManager.java | 8 +- .../push/PushNotificationManager.java | 2 +- .../textsecuregcm/push/ReceiptSender.java | 2 +- .../textsecuregcm/storage/Account.java | 12 +- .../storage/AccountsManager.java | 29 +- .../storage/ChangeNumberManager.java | 12 +- .../textsecuregcm/storage/Device.java | 17 +- .../storage/DeviceIdDeserializer.java | 41 +++ .../textsecuregcm/storage/KeysManager.java | 38 ++- .../storage/MessagePersister.java | 4 +- .../textsecuregcm/storage/MessagesCache.java | 54 ++- .../storage/MessagesDynamoDb.java | 20 +- .../storage/MessagesManager.java | 18 +- .../RefreshingAccountAndDeviceSupplier.java | 2 +- .../RepeatedUseECSignedPreKeyStore.java | 4 +- .../RepeatedUseKEMSignedPreKeyStore.java | 2 +- .../storage/RepeatedUseSignedPreKeyStore.java | 23 +- .../storage/SingleUseECPreKeyStore.java | 2 +- .../storage/SingleUseKEMPreKeyStore.java | 2 +- .../storage/SingleUsePreKeyStore.java | 20 +- .../util/DestinationDeviceValidator.java | 19 +- .../websocket/ProvisioningAddress.java | 4 +- .../websocket/WebsocketAddress.java | 8 +- .../MigrateSignedECPreKeysCommand.java | 2 +- .../workers/UnlinkDeviceCommand.java | 6 +- .../main/proto/org/signal/chat/device.proto | 4 +- .../src/main/proto/org/signal/chat/keys.proto | 5 +- ...blementRefreshRequirementProviderTest.java | 58 ++-- .../auth/BaseAccountAuthenticatorTest.java | 21 +- .../auth/CertificateGeneratorTest.java | 10 +- .../auth/OptionalAccessTest.java | 4 +- .../grpc/MockAuthenticationInterceptor.java | 4 +- .../controllers/AccountControllerTest.java | 8 +- .../controllers/AccountControllerV2Test.java | 16 +- .../controllers/DeviceControllerTest.java | 28 +- .../controllers/KeysControllerTest.java | 312 ++++++++++-------- .../controllers/MessageControllerTest.java | 67 ++-- .../entities/OutgoingMessageEntityTest.java | 4 +- .../grpc/AccountsGrpcServiceTest.java | 6 +- .../grpc/DevicesGrpcServiceTest.java | 40 +-- .../textsecuregcm/grpc/GrpcTestUtils.java | 2 +- .../grpc/KeysAnonymousGrpcServiceTest.java | 8 +- .../grpc/KeysGrpcServiceTest.java | 73 ++-- .../grpc/SimpleBaseGrpcTest.java | 2 +- .../textsecuregcm/push/APNSenderTest.java | 2 +- .../ApnPushNotificationSchedulerTest.java | 9 +- .../push/ClientPresenceManagerTest.java | 34 +- .../textsecuregcm/push/MessageSenderTest.java | 8 +- .../push/ProvisioningManagerTest.java | 25 +- .../push/PushLatencyManagerTest.java | 2 +- .../textsecuregcm/storage/AccountTest.java | 59 ++-- ...ntsManagerChangeNumberIntegrationTest.java | 4 +- ...ConcurrentModificationIntegrationTest.java | 6 +- .../storage/AccountsManagerTest.java | 154 +++++---- .../textsecuregcm/storage/AccountsTest.java | 95 ++++-- .../storage/ChangeNumberManagerTest.java | 138 +++++--- .../storage/KeysManagerTest.java | 81 +++-- .../MessagePersisterIntegrationTest.java | 8 +- .../storage/MessagePersisterTest.java | 39 +-- .../storage/MessagesCacheTest.java | 18 +- .../storage/MessagesDynamoDbTest.java | 105 +++--- .../storage/MessagesManagerTest.java | 4 +- ...efreshingAccountAndDeviceSupplierTest.java | 2 +- .../RepeatedUseECSignedPreKeyStoreTest.java | 4 +- .../RepeatedUseSignedPreKeyStoreTest.java | 38 ++- .../storage/SingleUsePreKeyStoreTest.java | 25 +- .../tests/util/AccountsHelper.java | 5 +- .../textsecuregcm/tests/util/AuthHelper.java | 30 +- .../tests/util/DevicesHelper.java | 8 +- .../tests/util/MessageHelper.java | 2 +- .../util/DestinationDeviceValidatorTest.java | 102 +++--- .../WebSocketConnectionIntegrationTest.java | 2 +- .../websocket/WebSocketConnectionTest.java | 61 ++-- 112 files changed, 1292 insertions(+), 1094 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEvent.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEventList.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceIdDeserializer.java diff --git a/integration-tests/src/main/java/org/signal/integration/Operations.java b/integration-tests/src/main/java/org/signal/integration/Operations.java index e631bfd40..2eb48fd05 100644 --- a/integration-tests/src/main/java/org/signal/integration/Operations.java +++ b/integration-tests/src/main/java/org/signal/integration/Operations.java @@ -236,7 +236,7 @@ public final class Operations { return authorized(user, Device.PRIMARY_ID); } - public RequestBuilder authorized(final TestUser user, final long deviceId) { + public RequestBuilder authorized(final TestUser user, final byte deviceId) { final String username = "%s.%d".formatted(user.aciUuid().toString(), deviceId); return authorized(username, user.accountPassword()); } diff --git a/integration-tests/src/main/java/org/signal/integration/TestDevice.java b/integration-tests/src/main/java/org/signal/integration/TestDevice.java index 59c5de53a..174175f4e 100644 --- a/integration-tests/src/main/java/org/signal/integration/TestDevice.java +++ b/integration-tests/src/main/java/org/signal/integration/TestDevice.java @@ -16,13 +16,13 @@ import org.signal.libsignal.protocol.state.SignedPreKeyRecord; public class TestDevice { - private final long deviceId; + private final byte deviceId; private final Map> signedPreKeys = new ConcurrentHashMap<>(); public static TestDevice create( - final long deviceId, + final byte deviceId, final IdentityKeyPair aciIdentityKeyPair, final IdentityKeyPair pniIdentityKeyPair) { final TestDevice device = new TestDevice(deviceId); @@ -31,11 +31,11 @@ public class TestDevice { return device; } - public TestDevice(final long deviceId) { + public TestDevice(final byte deviceId) { this.deviceId = deviceId; } - public long deviceId() { + public byte deviceId() { return deviceId; } diff --git a/integration-tests/src/main/java/org/signal/integration/TestUser.java b/integration-tests/src/main/java/org/signal/integration/TestUser.java index 2b71a79ff..09df391fc 100644 --- a/integration-tests/src/main/java/org/signal/integration/TestUser.java +++ b/integration-tests/src/main/java/org/signal/integration/TestUser.java @@ -30,7 +30,7 @@ public class TestUser { private final IdentityKeyPair aciIdentityKey; - private final Map devices = new ConcurrentHashMap<>(); + private final Map devices = new ConcurrentHashMap<>(); private final byte[] unidentifiedAccessKey; @@ -147,7 +147,7 @@ public class TestUser { this.registrationPassword = registrationPassword; } - public PreKeySetPublicView preKeys(final long deviceId, final boolean pni) { + public PreKeySetPublicView preKeys(final byte deviceId, final boolean pni) { final IdentityKeyPair identity = pni ? pniIdentityKey : aciIdentityKey; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java index 1e3e5b22a..6661375fd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java @@ -47,7 +47,7 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres } @VisibleForTesting - static Map buildDevicesEnabledMap(final Account account) { + static Map buildDevicesEnabledMap(final Account account) { return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled)); } @@ -68,17 +68,17 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres } @Override - public List> handleRequestFinished(final RequestEvent requestEvent) { + public List> handleRequestFinished(final RequestEvent requestEvent) { // Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did // change or if a devices was added or removed, all devices must disconnect and reauthenticate. if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) { - @SuppressWarnings("unchecked") final Map initialDevicesEnabled = - (Map) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED); + @SuppressWarnings("unchecked") final Map initialDevicesEnabled = + (Map) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED); return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> { - final Set deviceIdsToDisplace; - final Map currentDevicesEnabled = buildDevicesEnabledMap(account); + final Set deviceIdsToDisplace; + final Map currentDevicesEnabled = buildDevicesEnabledMap(account); if (!initialDevicesEnabled.equals(currentDevicesEnabled)) { deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java index d1def356d..4adab6cb4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java @@ -52,9 +52,9 @@ public class BaseAccountAuthenticator { this.clock = clock; } - static Pair getIdentifierAndDeviceId(final String basicUsername) { + static Pair getIdentifierAndDeviceId(final String basicUsername) { final String identifier; - final long deviceId; + final byte deviceId; final int deviceIdSeparatorIndex = basicUsername.indexOf(DEVICE_ID_SEPARATOR); @@ -63,7 +63,7 @@ public class BaseAccountAuthenticator { deviceId = Device.PRIMARY_ID; } else { identifier = basicUsername.substring(0, deviceIdSeparatorIndex); - deviceId = Long.parseLong(basicUsername.substring(deviceIdSeparatorIndex + 1)); + deviceId = Byte.parseByte(basicUsername.substring(deviceIdSeparatorIndex + 1)); } return new Pair<>(identifier, deviceId); @@ -75,9 +75,9 @@ public class BaseAccountAuthenticator { try { final UUID accountUuid; - final long deviceId; + final byte deviceId; { - final Pair identifierAndDeviceId = getIdentifierAndDeviceId(basicCredentials.getUsername()); + final Pair identifierAndDeviceId = getIdentifierAndDeviceId(basicCredentials.getUsername()); accountUuid = UUID.fromString(identifierAndDeviceId.first()); deviceId = identifierAndDeviceId.second(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BasicAuthorizationHeader.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BasicAuthorizationHeader.java index 05e4f4f27..54f99fe70 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BasicAuthorizationHeader.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BasicAuthorizationHeader.java @@ -11,10 +11,10 @@ import org.whispersystems.textsecuregcm.util.Pair; public class BasicAuthorizationHeader { private final String username; - private final long deviceId; + private final byte deviceId; private final String password; - private BasicAuthorizationHeader(final String username, final long deviceId, final String password) { + private BasicAuthorizationHeader(final String username, final byte deviceId, final String password) { this.username = username; this.deviceId = deviceId; this.password = password; @@ -59,9 +59,9 @@ public class BasicAuthorizationHeader { final String usernameComponent = credentials.substring(0, credentialSeparatorIndex); final String username; - final long deviceId; + final byte deviceId; { - final Pair identifierAndDeviceId = + final Pair identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(usernameComponent); username = identifierAndDeviceId.first(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java index d6b6346f0..f19e70c2a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/OptionalAccess.java @@ -29,7 +29,7 @@ public class OptionalAccess { verify(requestAccount, accessKey, targetAccount); if (!deviceSelector.equals("*")) { - long deviceId = Long.parseLong(deviceSelector); + byte deviceId = Byte.parseByte(deviceSelector); Optional targetDevice = targetAccount.get().getDevice(deviceId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java index 4ba8db0a5..522a49964 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java @@ -26,7 +26,7 @@ public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRef } @Override - public List> handleRequestFinished(final RequestEvent requestEvent) { + public List> handleRequestFinished(final RequestEvent requestEvent) { final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY); if (initialNumber != null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java index 2629f69d9..eda57f52e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManager.java @@ -157,7 +157,7 @@ public class RegistrationLockVerificationManager { registrationRecoveryPasswordsManager.removeForNumber(updatedAccount.getNumber()); } - final List deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); + final List deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds); try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java index 2f75127ec..5e020381f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java @@ -30,5 +30,5 @@ public interface WebsocketRefreshRequirementProvider { * @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a * result of the observed request */ - List> handleRequestFinished(RequestEvent requestEvent); + List> handleRequestFinished(RequestEvent requestEvent); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java index 906056986..55693101c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticatedDevice.java @@ -7,5 +7,5 @@ package org.whispersystems.textsecuregcm.auth.grpc; import java.util.UUID; -public record AuthenticatedDevice(UUID accountIdentifier, long deviceId) { +public record AuthenticatedDevice(UUID accountIdentifier, byte deviceId) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java index 68c57552a..26577d50e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/AuthenticationUtil.java @@ -17,7 +17,7 @@ import org.whispersystems.textsecuregcm.storage.Device; public class AuthenticationUtil { static final Context.Key CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci"); - static final Context.Key CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id"); + static final Context.Key CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id"); /** * Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if @@ -30,7 +30,7 @@ public class AuthenticationUtil { */ public static AuthenticatedDevice requireAuthenticatedDevice() { @Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get(); - @Nullable final Long deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get(); + @Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get(); if (accountIdentifier != null && deviceId != null) { return new AuthenticatedDevice(accountIdentifier, deviceId); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 82c72f637..65f5b156c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -217,7 +217,7 @@ public class AccountController { @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent, @NotNull @Valid AccountAttributes attributes) { final Account account = disabledPermittedAuth.getAccount(); - final long deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId(); + final byte deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId(); final Account updatedAccount = accounts.update(account, a -> { a.getDevice(deviceId).ifPresent(d -> { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index f7ae518bd..1dbd6400d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -135,7 +135,7 @@ public class DeviceController { @Produces(MediaType.APPLICATION_JSON) @Path("/{device_id}") @ChangesDeviceEnabledState - public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) { + public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") byte deviceId) { Account account = auth.getAccount(); if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); @@ -256,7 +256,7 @@ public class DeviceController { @Path("/capabilities") public void setCapabilities(@Auth AuthenticatedAccount auth, @NotNull @Valid DeviceCapabilities capabilities) { assert (auth.getAuthenticatedDevice() != null); - final long deviceId = auth.getAuthenticatedDevice().getId(); + final byte deviceId = auth.getAuthenticatedDevice().getId(); accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 31a407ad3..7de1cb35e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -332,7 +332,7 @@ public class KeysController { return account.getDevices().stream().filter(Device::isEnabled).toList(); } try { - long id = Long.parseLong(deviceId); + byte id = Byte.parseByte(deviceId); return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of()); } catch (NumberFormatException e) { throw new WebApplicationException(Response.status(422).build()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index eff65ab67..ae9f8cd22 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -283,7 +283,7 @@ public class MessageController { checkStoryRateLimit(destination.get(), userAgent); } - final Set excludedDeviceIds; + final Set excludedDeviceIds; if (isSyncMessage) { excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId()); @@ -346,7 +346,7 @@ public class MessageController { /** * Build mapping of accounts to devices/registration IDs. */ - private Map>> buildDeviceIdAndRegistrationIdMap( + private Map>> buildDeviceIdAndRegistrationIdMap( MultiRecipientMessage multiRecipientMessage, Map accountsByServiceIdentifier) { @@ -403,7 +403,7 @@ public class MessageController { checkAccessKeys(accessKeys, accountsByServiceIdentifier.values()); } - final Map>> accountToDeviceIdAndRegistrationIdMap = + final Map>> accountToDeviceIdAndRegistrationIdMap = buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier); // We might filter out all the recipients of a story (if none have enabled stories). @@ -420,7 +420,7 @@ public class MessageController { checkStoryRateLimit(account, userAgent); } - Set deviceIds = accountToDeviceIdAndRegistrationIdMap + Set deviceIds = accountToDeviceIdAndRegistrationIdMap .getOrDefault(account, Collections.emptySet()) .stream() .map(Pair::first) @@ -678,7 +678,7 @@ public class MessageController { try { Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null); - Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); + Byte sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); envelope = incomingMessage.toEnvelope( destinationIdentifier, sourceAccount, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java index 4f8745e1c..ba79d3120 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MismatchedDevicesException.java @@ -9,19 +9,19 @@ import java.util.List; public class MismatchedDevicesException extends Exception { - private final List missingDevices; - private final List extraDevices; + private final List missingDevices; + private final List extraDevices; - public MismatchedDevicesException(List missingDevices, List extraDevices) { + public MismatchedDevicesException(List missingDevices, List extraDevices) { this.missingDevices = missingDevices; this.extraDevices = extraDevices; } - public List getMissingDevices() { + public List getMissingDevices() { return missingDevices; } - public List getExtraDevices() { + public List getExtraDevices() { return extraDevices; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java index e29e89c4e..831f2f107 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java @@ -47,7 +47,7 @@ public class ProvisioningController { rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid()); - if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0), + if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, (byte) 0), Base64.getMimeDecoder().decode(message.body()))) { throw new WebApplicationException(Response.Status.NOT_FOUND); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java index 7e914f176..484c9f9ef 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StaleDevicesException.java @@ -9,13 +9,14 @@ import java.util.List; public class StaleDevicesException extends Exception { - private final List staleDevices; - public StaleDevicesException(List staleDevices) { + private final List staleDevices; + + public StaleDevicesException(List staleDevices) { this.staleDevices = staleDevices; } - public List getStaleDevices() { + public List getStaleDevices() { return staleDevices; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountDataReportResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountDataReportResponse.java index e6276a6e5..6d799ed4f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountDataReportResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/AccountDataReportResponse.java @@ -98,7 +98,7 @@ public record AccountDataReportResponse(UUID reportId, } - public record DeviceDataReport(long id, + public record DeviceDataReport(byte id, @JsonFormat(pattern = DATE_FORMAT, timezone = UTC) Instant lastSeen, @JsonFormat(pattern = DATE_FORMAT, timezone = UTC) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index 7f37aa404..993ef9aaf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -54,7 +54,7 @@ public record ChangeNumberRequest( @Schema(description=""" A new signed elliptic-curve prekey for each enabled device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - @NotNull @Valid Map devicePniSignedPrekeys, + @NotNull @Valid Map devicePniSignedPrekeys, @Schema(description=""" A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. @@ -62,10 +62,10 @@ public record ChangeNumberRequest( If present, must contain one prekey per enabled device including this one. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + @Valid Map devicePniPqLastResortPrekeys, @Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one") - @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { + @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceResponse.java index 4539ae2a0..6c111b1ef 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceResponse.java @@ -18,12 +18,12 @@ public class DeviceResponse { private UUID pni; @JsonProperty - private long deviceId; + private byte deviceId; @VisibleForTesting public DeviceResponse() {} - public DeviceResponse(UUID uuid, UUID pni, long deviceId) { + public DeviceResponse(UUID uuid, UUID pni, byte deviceId) { this.uuid = uuid; this.pni = pni; this.deviceId = deviceId; @@ -37,7 +37,7 @@ public class DeviceResponse { return pni; } - public long getDeviceId() { + public byte getDeviceId() { return deviceId; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index edcafba94..3e0f8e0f7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -12,11 +12,11 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; -public record IncomingMessage(int type, long destinationDeviceId, int destinationRegistrationId, String content) { +public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) { public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier, @Nullable Account sourceAccount, - @Nullable Long sourceDeviceId, + @Nullable Byte sourceDeviceId, final long timestamp, final boolean story, final boolean urgent, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java index ef6b6eda4..169e803d2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MismatchedDevices.java @@ -12,9 +12,9 @@ import java.util.List; public record MismatchedDevices(@JsonProperty @Schema(description = "Devices present on the account but absent in the request") - List missingDevices, + List missingDevices, @JsonProperty @Schema(description = "Devices absent on the request but present in the account") - List extraDevices) { + List extraDevices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java index ba6421c61..edec47ab6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java @@ -40,7 +40,7 @@ public record MultiRecipientMessage( @JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class) @JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class) ServiceIdentifier uuid, - @Min(1) long deviceId, + @Min(1) byte deviceId, @Min(0) @Max(65535) int registrationId, @Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index 3db331a18..47501985a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -22,7 +22,7 @@ public record PhoneNumberIdentityKeyDistributionRequest( @JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class) @Schema(description="the new identity key for this account's phone-number identity") IdentityKey pniIdentityKey, - + @NotNull @Valid @ArraySchema( @@ -32,26 +32,26 @@ public record PhoneNumberIdentityKeyDistributionRequest( Exactly one message must be supplied for each enabled device other than the sending (primary) device. """)) List<@NotNull @Valid IncomingMessage> deviceMessages, - + @NotNull @Valid @Schema(description=""" A new signed elliptic-curve prekey for each enabled device on the account, including this one. Each must be accompanied by a valid signature from the new identity key in this request.""") - Map devicePniSignedPrekeys, - + Map devicePniSignedPrekeys, + @Schema(description=""" A new signed post-quantum last-resort prekey for each enabled device on the account, including this one. May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored. If present, must contain one prekey per enabled device including this one. Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped. Each must be accompanied by a valid signature from the new identity key in this request.""") - @Valid Map devicePniPqLastResortPrekeys, + @Valid Map devicePniPqLastResortPrekeys, @NotNull @Valid @Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.") - Map pniRegistrationIds) { + Map pniRegistrationIds) { @AssertTrue public boolean isSignatureValidOnEachSignedPreKey() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java index c818746ec..f765d7d4e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponse.java @@ -40,7 +40,7 @@ public class PreKeyResponse { @VisibleForTesting @JsonIgnore - public PreKeyResponseItem getDevice(int deviceId) { + public PreKeyResponseItem getDevice(byte deviceId) { for (PreKeyResponseItem device : devices) { if (device.getDeviceId() == deviceId) return device; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java index 0cf519c9f..68d1680fd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PreKeyResponseItem.java @@ -12,7 +12,7 @@ public class PreKeyResponseItem { @JsonProperty @Schema(description="the device ID of the device to which this item pertains") - private long deviceId; + private byte deviceId; @JsonProperty @Schema(description="the registration ID for the device") @@ -33,7 +33,8 @@ public class PreKeyResponseItem { public PreKeyResponseItem() {} - public PreKeyResponseItem(long deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey, KEMSignedPreKey pqPreKey) { + public PreKeyResponseItem(byte deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey, + KEMSignedPreKey pqPreKey) { this.deviceId = deviceId; this.registrationId = registrationId; this.signedPreKey = signedPreKey; @@ -62,7 +63,7 @@ public class PreKeyResponseItem { } @VisibleForTesting - public long getDeviceId() { + public byte getDeviceId() { return deviceId; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java index bed26a51f..d49e077b9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/StaleDevices.java @@ -12,5 +12,5 @@ import java.util.List; public record StaleDevices(@JsonProperty @Schema(description = "Devices that are no longer active") - List staleDevices) { + List staleDevices) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEvent.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEvent.java deleted file mode 100644 index a7e58c3bd..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEvent.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.entities; - -import com.fasterxml.jackson.annotation.JsonProperty; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotEmpty; - -public class UnregisteredEvent { - - @JsonProperty - @NotEmpty - private String registrationId; - - @JsonProperty - private String canonicalId; - - @JsonProperty - @NotEmpty - private String number; - - @JsonProperty - @Min(1) - private int deviceId; - - @JsonProperty - private long timestamp; - - public String getRegistrationId() { - return registrationId; - } - - public String getCanonicalId() { - return canonicalId; - } - - public String getNumber() { - return number; - } - - public int getDeviceId() { - return deviceId; - } - - public long getTimestamp() { - return timestamp; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEventList.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEventList.java deleted file mode 100644 index 91896ab0b..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/UnregisteredEventList.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.entities; - -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.LinkedList; -import java.util.List; - -public class UnregisteredEventList { - - @JsonProperty - private List devices; - - public List getDevices() { - if (devices == null) return new LinkedList<>(); - else return devices; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java new file mode 100644 index 000000000..afcdc0fa5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DeviceIdUtil.java @@ -0,0 +1,18 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.grpc; + +import io.grpc.Status; + +public class DeviceIdUtil { + + static byte validate(int deviceId) { + if (deviceId > Byte.MAX_VALUE) { + throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException(); + } + return (byte) deviceId; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java index 449cf9426..1bfadc7b2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcService.java @@ -78,18 +78,19 @@ public class DevicesGrpcService extends ReactorDevicesGrpc.DevicesImplBase { if (request.getId() == Device.PRIMARY_ID) { throw Status.INVALID_ARGUMENT.withDescription("Cannot remove primary device").asRuntimeException(); } + final byte deviceId = DeviceIdUtil.validate(request.getId()); final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedPrimaryDevice(); return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier())) .map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException)) .flatMap(account -> Flux.merge( - Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), request.getId())), - Mono.fromFuture(() -> keysManager.delete(account.getUuid(), request.getId()))) - .then(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> a.removeDevice(request.getId())))) + Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId)), + Mono.fromFuture(() -> keysManager.delete(account.getUuid(), deviceId))) + .then(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> a.removeDevice(deviceId)))) // Some messages may have arrived while we were performing the other updates; make a best effort to clear // those out, too - .then(Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), request.getId())))) + .then(Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId)))) .thenReturn(RemoveDeviceResponse.newBuilder().build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java index 48a4e3e21..43dbd2b7c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysAnonymousGrpcService.java @@ -39,12 +39,14 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony final ServiceIdentifier serviceIdentifier = ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getRequest().getTargetIdentifier()); + final byte deviceId = DeviceIdUtil.validate(request.getRequest().getDeviceId()); + return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier)) .flatMap(Mono::justOrEmpty) .switchIfEmpty(Mono.error(Status.UNAUTHENTICATED.asException())) .flatMap(targetAccount -> UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray()) - ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), request.getRequest().getDeviceId(), keysManager) + ? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager) : Mono.error(Status.UNAUTHENTICATED.asException())); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java index f0bc21a22..5f80be141 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcHelper.java @@ -27,11 +27,11 @@ import reactor.util.function.Tuples; class KeysGrpcHelper { @VisibleForTesting - static final long ALL_DEVICES = 0; + static final byte ALL_DEVICES = 0; static Mono getPreKeys(final Account targetAccount, final IdentityType identityType, - final long targetDeviceId, + final byte targetDeviceId, final KeysManager keysManager) { final Flux devices = targetDeviceId == ALL_DEVICES @@ -73,7 +73,8 @@ class KeysGrpcHelper { return builder; }) - .map(builder -> Tuples.of(device.getId(), builder.build())); + // Cast device IDs to `int` to match data types in the response object’s protobuf definition + .map(builder -> Tuples.of((int) device.getId(), builder.build())); }) .collectMap(Tuple2::getT1, Tuple2::getT2) .map(preKeyBundles -> GetPreKeysResponse.newBuilder() diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java index 6d3c3ccf9..05b921321 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/KeysGrpcService.java @@ -124,17 +124,19 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase { final ServiceIdentifier targetIdentifier = ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier()); + final byte deviceId = DeviceIdUtil.validate(request.getDeviceId()); + final String rateLimitKey = authenticatedDevice.accountIdentifier() + "." + authenticatedDevice.deviceId() + "__" + targetIdentifier.uuid() + "." + - request.getDeviceId(); + deviceId; return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey) .then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier)) .flatMap(Mono::justOrEmpty)) .switchIfEmpty(Mono.error(Status.NOT_FOUND.asException())) .flatMap(targetAccount -> - KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), request.getDeviceId(), keysManager)); + KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), deviceId, keysManager)); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index 74b1ec164..bdcf02ef9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -83,7 +83,14 @@ public class MultiRecipientMessageProvider implements MessageBodyReader Byte.MAX_VALUE) { + throw new BadRequestException("Invalid device ID"); + } + deviceId = (byte) deviceIdLong; + } int registrationId = readU16(entityStream); byte[] perRecipientKeyMaterial = entityStream.readNBytes(48); if (perRecipientKeyMaterial.length != 48) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationScheduler.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationScheduler.java index 19d5c0d27..2be3bf99c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationScheduler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ApnPushNotificationScheduler.java @@ -300,7 +300,7 @@ public class ApnPushNotificationScheduler implements Managed { } @VisibleForTesting - static Optional> getSeparated(String encoded) { + static Optional> getSeparated(String encoded) { try { if (encoded == null) return Optional.empty(); @@ -311,7 +311,7 @@ public class ApnPushNotificationScheduler implements Managed { return Optional.empty(); } - return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1]))); + return Optional.of(new Pair<>(parts[0], Byte.parseByte(parts[1]))); } catch (NumberFormatException e) { logger.warn("Badly formatted: " + encoded, e); return Optional.empty(); @@ -338,7 +338,7 @@ public class ApnPushNotificationScheduler implements Managed { final Optional maybeAccount = accountsManager.getByAccountIdentifier(UUID.fromString(parts[0])); - return maybeAccount.flatMap(account -> account.getDevice(Long.parseLong(parts[1]))) + return maybeAccount.flatMap(account -> account.getDevice(Byte.parseByte(parts[1]))) .map(device -> new Pair<>(maybeAccount.get(), device)); } catch (final NumberFormatException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java index fc7d455e1..78a43eb85 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java @@ -21,6 +21,7 @@ import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; @@ -34,7 +35,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import io.micrometer.core.instrument.Metrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; @@ -162,7 +162,8 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter connection.sync().upstream().commands().unsubscribe(getManagerPresenceChannel(managerId))); } - public void setPresent(final UUID accountUuid, final long deviceId, final DisplacedPresenceListener displacementListener) { + public void setPresent(final UUID accountUuid, final byte deviceId, + final DisplacedPresenceListener displacementListener) { try (final Timer.Context ignored = setPresenceTimer.time()) { final String presenceKey = getPresenceKey(accountUuid, deviceId); @@ -182,12 +183,12 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter deviceIds) { + public void disconnectAllPresences(final UUID accountUuid, final List deviceIds) { List presenceKeys = new ArrayList<>(); deviceIds.forEach(deviceId -> { @@ -208,7 +209,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter connection.sync().exists(getPresenceKey(accountUuid, deviceId))) == 1; } } - public boolean isLocallyPresent(final UUID accountUuid, final long deviceId) { + public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) { return displacementListenersByPresenceKey.containsKey(getPresenceKey(accountUuid, deviceId)); } - public boolean clearPresence(final UUID accountUuid, final long deviceId, final DisplacedPresenceListener listener) { + public boolean clearPresence(final UUID accountUuid, final byte deviceId, final DisplacedPresenceListener listener) { final String presenceKey = getPresenceKey(accountUuid, deviceId); if (displacementListenersByPresenceKey.remove(presenceKey, listener)) { return clearPresence(presenceKey); @@ -337,7 +338,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter { if (pushRecord != null) { final Duration latency = Duration.between(pushRecord.timestamp(), Instant.now()); @@ -114,7 +114,7 @@ public class PushLatencyManager { } @VisibleForTesting - CompletableFuture takePushRecord(final UUID accountUuid, final long deviceId) { + CompletableFuture takePushRecord(final UUID accountUuid, final byte deviceId) { final String key = getFirstUnacknowledgedPushKey(accountUuid, deviceId); return redisCluster.withCluster(connection -> { @@ -141,7 +141,7 @@ public class PushLatencyManager { }); } - private static String getFirstUnacknowledgedPushKey(final UUID accountUuid, final long deviceId) { + private static String getFirstUnacknowledgedPushKey(final UUID accountUuid, final byte deviceId) { return "push_latency::v2::" + accountUuid.toString() + "::" + deviceId; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java index 8d48e5189..6e0543ee8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/PushNotificationManager.java @@ -47,7 +47,7 @@ public class PushNotificationManager { this.pushLatencyManager = pushLatencyManager; } - public void sendNewMessageNotification(final Account destination, final long destinationDeviceId, final boolean urgent) throws NotPushRegisteredException { + public void sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException { final Device device = destination.getDevice(destinationDeviceId).orElseThrow(NotPushRegisteredException::new); final Pair tokenAndType = getToken(device); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index ac4c9498f..f6605fa56 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -34,7 +34,7 @@ public class ReceiptSender { ; } - public void sendReceipt(ServiceIdentifier sourceIdentifier, long sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) { + public void sendReceipt(ServiceIdentifier sourceIdentifier, byte sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) { if (sourceIdentifier.equals(destinationIdentifier)) { return; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index 059d85df2..e756b6274 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -223,7 +223,7 @@ public class Account { this.devices.add(device); } - public void removeDevice(final long deviceId) { + public void removeDevice(final byte deviceId) { requireNotStale(); this.devices.removeIf(device -> device.getId() == deviceId); @@ -241,7 +241,7 @@ public class Account { return getDevice(Device.PRIMARY_ID); } - public Optional getDevice(final long deviceId) { + public Optional getDevice(final byte deviceId) { requireNotStale(); return devices.stream().filter(device -> device.getId() == deviceId).findFirst(); @@ -281,15 +281,19 @@ public class Account { return getPrimaryDevice().map(Device::isEnabled).orElse(false); } - public long getNextDeviceId() { + public byte getNextDeviceId() { requireNotStale(); - long candidateId = Device.PRIMARY_ID + 1; + byte candidateId = Device.PRIMARY_ID + 1; while (getDevice(candidateId).isPresent()) { candidateId++; } + if (candidateId <= Device.PRIMARY_ID) { + throw new RuntimeException("device ID overflow"); + } + return candidateId; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 472b32db1..d4890d40d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -268,9 +268,9 @@ public class AccountsManager { public Account changeNumber(final Account account, final String targetNumber, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, - @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, + @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException { final String originalNumber = account.getNumber(); final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier(); @@ -369,9 +369,9 @@ public class AccountsManager { public Account updatePniKeys(final Account account, final IdentityKey pniIdentityKey, - final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, - final Map pniRegistrationIds) throws MismatchedDevicesException { + final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, + final Map pniRegistrationIds) throws MismatchedDevicesException { validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); final UUID pni = account.getPhoneNumberIdentifier(); @@ -395,8 +395,8 @@ public class AccountsManager { private boolean setPniKeys(final Account account, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniRegistrationIds) { + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniRegistrationIds) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { return false; } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { @@ -424,9 +424,9 @@ public class AccountsManager { } private void validateDevices(final Account account, - @Nullable final Map pniSignedPreKeys, - @Nullable final Map pniPqLastResortPreKeys, - @Nullable final Map pniRegistrationIds) throws MismatchedDevicesException { + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys, + @Nullable final Map pniRegistrationIds) throws MismatchedDevicesException { if (pniSignedPreKeys == null && pniRegistrationIds == null) { return; } else if (pniSignedPreKeys == null || pniRegistrationIds == null) { @@ -580,7 +580,7 @@ public class AccountsManager { } /** - * Specialized version of {@link #updateDevice(Account, long, Consumer)} that minimizes potentially contentious and + * Specialized version of {@link #updateDevice(Account, byte, Consumer)} that minimizes potentially contentious and * redundant updates of {@code device.lastSeen} */ public Account updateDeviceLastSeen(Account account, Device device, final long lastSeen) { @@ -741,7 +741,7 @@ public class AccountsManager { return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); } - public Account updateDevice(Account account, long deviceId, Consumer deviceUpdater) { + public Account updateDevice(Account account, byte deviceId, Consumer deviceUpdater) { return update(account, a -> { a.getDevice(deviceId).ifPresent(deviceUpdater); // assume that all updaters passed to the public method actually modify the device @@ -749,7 +749,8 @@ public class AccountsManager { }); } - public CompletableFuture updateDeviceAsync(final Account account, final long deviceId, final Consumer deviceUpdater) { + public CompletableFuture updateDeviceAsync(final Account account, final byte deviceId, + final Consumer deviceUpdater) { return updateAsync(account, a -> { a.getDevice(deviceId).ifPresent(deviceUpdater); // assume that all updaters passed to the public method actually modify the device diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java index d5e8a3549..92079e73c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManager.java @@ -43,10 +43,10 @@ public class ChangeNumberManager { public Account changeNumber(final Account account, final String number, @Nullable final IdentityKey pniIdentityKey, - @Nullable final Map deviceSignedPreKeys, - @Nullable final Map devicePqLastResortPreKeys, + @Nullable final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, @Nullable final List deviceMessages, - @Nullable final Map pniRegistrationIds) + @Nullable final Map pniRegistrationIds) throws InterruptedException, MismatchedDevicesException, StaleDevicesException { if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) { @@ -83,10 +83,10 @@ public class ChangeNumberManager { public Account updatePniKeys(final Account account, final IdentityKey pniIdentityKey, - final Map deviceSignedPreKeys, - @Nullable final Map devicePqLastResortPreKeys, + final Map deviceSignedPreKeys, + @Nullable final Map devicePqLastResortPreKeys, final List deviceMessages, - final Map pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { + final Map pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException { validateDeviceMessages(account, deviceMessages); // Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java index 37363d686..ad66fb7b4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Device.java @@ -6,11 +6,12 @@ package org.whispersystems.textsecuregcm.storage; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import java.util.List; import java.util.OptionalInt; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; -import java.util.stream.LongStream; +import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; @@ -19,13 +20,15 @@ import org.whispersystems.textsecuregcm.identity.IdentityType; public class Device { - public static final long PRIMARY_ID = 1; - public static final int MAXIMUM_DEVICE_ID = 256; + public static final byte PRIMARY_ID = 1; + public static final byte MAXIMUM_DEVICE_ID = Byte.MAX_VALUE; public static final int MAX_REGISTRATION_ID = 0x3FFF; - public static final List ALL_POSSIBLE_DEVICE_IDS = LongStream.range(1, MAXIMUM_DEVICE_ID).boxed().collect(Collectors.toList()); + public static final List ALL_POSSIBLE_DEVICE_IDS = IntStream.range(Device.PRIMARY_ID, MAXIMUM_DEVICE_ID).boxed() + .map(Integer::byteValue).collect(Collectors.toList()); + @JsonDeserialize(using = DeviceIdDeserializer.class) @JsonProperty - private long id; + private byte id; @JsonProperty private String name; @@ -135,11 +138,11 @@ public class Device { } } - public long getId() { + public byte getId() { return id; } - public void setId(long id) { + public void setId(byte id) { this.id = id; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceIdDeserializer.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceIdDeserializer.java new file mode 100644 index 000000000..649180c8e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceIdDeserializer.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import java.io.IOException; + +/** + * The built-in {@link com.fasterxml.jackson.databind.deser.std.NumberDeserializers.ByteDeserializer} will return + * negative values—both verbatim and by coercing 128…255. We prefer this invalid data to fail fast, so this + * is a simpler and stricter deserializer. + */ +public class DeviceIdDeserializer extends JsonDeserializer { + + @Override + public Byte deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + + byte value = p.getByteValue(); + + if (value < Device.PRIMARY_ID) { + throw new DeviceIdDeserializationException(); + } + + return value; + } + + static class DeviceIdDeserializationException extends IOException { + + DeviceIdDeserializationException() { + super("Invalid Device ID"); + } + + } + + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 65167ceb3..41cef17ee 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -42,12 +42,12 @@ public class KeysManager { this.dynamicConfigurationManager = dynamicConfigurationManager; } - public CompletableFuture store(final UUID identifier, final long deviceId, final List keys) { + public CompletableFuture store(final UUID identifier, final byte deviceId, final List keys) { return store(identifier, deviceId, keys, null, null, null); } public CompletableFuture store( - final UUID identifier, final long deviceId, + final UUID identifier, final byte deviceId, @Nullable final List ecKeys, @Nullable final List pqKeys, @Nullable final ECSignedPreKey ecSignedPreKey, @@ -63,7 +63,8 @@ public class KeysManager { storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys)); } - if (ecSignedPreKey != null && dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { + if (ecSignedPreKey != null && dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration() + .storeEcSignedPreKeys()) { storeFutures.add(ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey)); } @@ -74,7 +75,7 @@ public class KeysManager { return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])); } - public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { + public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { return ecSignedPreKeys.store(identifier, keys); } else { @@ -82,27 +83,30 @@ public class KeysManager { } } - public CompletableFuture storeEcSignedPreKeyIfAbsent(final UUID identifier, final long deviceId, final ECSignedPreKey signedPreKey) { + public CompletableFuture storeEcSignedPreKeyIfAbsent(final UUID identifier, final byte deviceId, + final ECSignedPreKey signedPreKey) { return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey); } - public CompletableFuture storePqLastResort(final UUID identifier, final Map keys) { + public CompletableFuture storePqLastResort(final UUID identifier, final Map keys) { return pqLastResortKeys.store(identifier, keys); } - public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) { + public CompletableFuture storeEcOneTimePreKeys(final UUID identifier, final byte deviceId, + final List preKeys) { return ecPreKeys.store(identifier, deviceId, preKeys); } - public CompletableFuture storeKemOneTimePreKeys(final UUID identifier, final long deviceId, final List preKeys) { + public CompletableFuture storeKemOneTimePreKeys(final UUID identifier, final byte deviceId, + final List preKeys) { return pqPreKeys.store(identifier, deviceId, preKeys); } - public CompletableFuture> takeEC(final UUID identifier, final long deviceId) { + public CompletableFuture> takeEC(final UUID identifier, final byte deviceId) { return ecPreKeys.take(identifier, deviceId); } - public CompletableFuture> takePQ(final UUID identifier, final long deviceId) { + public CompletableFuture> takePQ(final UUID identifier, final byte deviceId) { return pqPreKeys.take(identifier, deviceId) .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) @@ -110,26 +114,26 @@ public class KeysManager { } @VisibleForTesting - CompletableFuture> getLastResort(final UUID identifier, final long deviceId) { + CompletableFuture> getLastResort(final UUID identifier, final byte deviceId) { return pqLastResortKeys.find(identifier, deviceId); } - public CompletableFuture> getEcSignedPreKey(final UUID identifier, final long deviceId) { + public CompletableFuture> getEcSignedPreKey(final UUID identifier, final byte deviceId) { return ecSignedPreKeys.find(identifier, deviceId); } - public CompletableFuture> getPqEnabledDevices(final UUID identifier) { + public CompletableFuture> getPqEnabledDevices(final UUID identifier) { return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture(); } - public CompletableFuture getEcCount(final UUID identifier, final long deviceId) { + public CompletableFuture getEcCount(final UUID identifier, final byte deviceId) { return ecPreKeys.getCount(identifier, deviceId); } - public CompletableFuture getPqCount(final UUID identifier, final long deviceId) { + public CompletableFuture getPqCount(final UUID identifier, final byte deviceId) { return pqPreKeys.getCount(identifier, deviceId); } - + public CompletableFuture delete(final UUID accountUuid) { return CompletableFuture.allOf( ecPreKeys.delete(accountUuid), @@ -140,7 +144,7 @@ public class KeysManager { pqLastResortKeys.delete(accountUuid)); } - public CompletableFuture delete(final UUID accountUuid, final long deviceId) { + public CompletableFuture delete(final UUID accountUuid, final byte deviceId) { return CompletableFuture.allOf( ecPreKeys.delete(accountUuid, deviceId), pqPreKeys.delete(accountUuid, deviceId), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index 09860d290..d2dbc64fe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -137,7 +137,7 @@ public class MessagePersister implements Managed { for (final String queue : queuesToPersist) { final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue); - final long deviceId = MessagesCache.getDeviceIdFromQueueName(queue); + final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queue); try { persistQueue(accountUuid, deviceId); @@ -161,7 +161,7 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void persistQueue(final UUID accountUuid, final long deviceId) throws MessagePersistenceException { + void persistQueue(final UUID accountUuid, final byte deviceId) throws MessagePersistenceException { final Optional maybeAccount = accountsManager.getByAccountIdentifier(accountUuid); if (maybeAccount.isEmpty()) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index c0be72667..9386dfdbe 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -155,7 +155,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } } - public long insert(final UUID guid, final UUID destinationUuid, final long destinationDevice, + public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope message) { final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); return (long) insertTimer.record(() -> @@ -168,7 +168,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } public CompletableFuture> remove(final UUID destinationUuid, - final long destinationDevice, + final byte destinationDevice, final UUID messageGuid) { return remove(destinationUuid, destinationDevice, List.of(messageGuid)) @@ -177,7 +177,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp @SuppressWarnings("unchecked") public CompletableFuture> remove(final UUID destinationUuid, - final long destinationDevice, + final byte destinationDevice, final List messageGuids) { return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice), @@ -202,12 +202,12 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp }, messageDeletionExecutorService); } - public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) { + public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) { return readDeleteCluster.withBinaryCluster( connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0); } - public Publisher get(final UUID destinationUuid, final long destinationDevice) { + public Publisher get(final UUID destinationUuid, final byte destinationDevice) { final long earliestAllowableEphemeralTimestamp = clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); @@ -238,7 +238,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp; } - private void discardStaleEphemeralMessages(final UUID destinationUuid, final long destinationDevice, + private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice, Flux staleEphemeralMessages) { staleEphemeralMessages .map(e -> UUID.fromString(e.getServerGuid())) @@ -251,7 +251,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @VisibleForTesting - Flux getAllMessages(final UUID destinationUuid, final long destinationDevice) { + Flux getAllMessages(final UUID destinationUuid, final byte destinationDevice) { // fetch messages by page return getNextMessagePage(destinationUuid, destinationDevice, -1) @@ -284,7 +284,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp }); } - private Flux, Long>> getNextMessagePage(final UUID destinationUuid, final long destinationDevice, + private Flux, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, long messageId) { return getItemsScript.executeBinaryReactive( @@ -315,7 +315,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @VisibleForTesting - List getMessagesToPersist(final UUID accountUuid, final long destinationDevice, + List getMessagesToPersist(final UUID accountUuid, final byte destinationDevice, final int limit) { return getMessagesTimer.record(() -> { final List> scoredMessages = readDeleteCluster.withBinaryCluster( @@ -336,16 +336,14 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } public CompletableFuture clear(final UUID destinationUuid) { - final CompletableFuture[] clearFutures = new CompletableFuture[Device.MAXIMUM_DEVICE_ID]; - - for (int deviceId = 0; deviceId < Device.MAXIMUM_DEVICE_ID; deviceId++) { - clearFutures[deviceId] = clear(destinationUuid, deviceId); - } - - return CompletableFuture.allOf(clearFutures); + return CompletableFuture.allOf( + Device.ALL_POSSIBLE_DEVICE_IDS.stream() + .map(deviceId -> clear(destinationUuid, deviceId)) + .toList() + .toArray(CompletableFuture[]::new)); } - public CompletableFuture clear(final UUID destinationUuid, final long deviceId) { + public CompletableFuture clear(final UUID destinationUuid, final byte deviceId) { final Timer.Sample sample = Timer.start(); return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId), @@ -368,23 +366,23 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp String.valueOf(limit)))); } - void addQueueToPersist(final UUID accountUuid, final long deviceId) { + void addQueueToPersist(final UUID accountUuid, final byte deviceId) { readDeleteCluster.useBinaryCluster(connection -> connection.sync() .zadd(getQueueIndexKey(accountUuid, deviceId), ZAddArgs.Builder.nx(), System.currentTimeMillis(), getMessageQueueKey(accountUuid, deviceId))); } - void lockQueueForPersistence(final UUID accountUuid, final long deviceId) { + void lockQueueForPersistence(final UUID accountUuid, final byte deviceId) { readDeleteCluster.useBinaryCluster( connection -> connection.sync().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE)); } - void unlockQueueForPersistence(final UUID accountUuid, final long deviceId) { + void unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) { readDeleteCluster.useBinaryCluster( connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId))); } - public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, + public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId, final MessageAvailabilityListener listener) { final String queueName = getQueueName(destinationUuid, deviceId); @@ -500,7 +498,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @VisibleForTesting - static String getQueueName(final UUID accountUuid, final long deviceId) { + static String getQueueName(final UUID accountUuid, final byte deviceId) { return accountUuid + "::" + deviceId; } @@ -513,15 +511,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @VisibleForTesting - static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) { + static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) { return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final long deviceId) { + private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) { return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) { + private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) { return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId)); } @@ -529,7 +527,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8); } - private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) { + private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) { return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } @@ -539,7 +537,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag))); } - static long getDeviceIdFromQueueName(final String queueName) { - return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}'))); + static byte getDeviceIdFromQueueName(final String queueName) { + return Byte.parseByte(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}'))); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java index a4a85608f..aee8a7352 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDb.java @@ -83,11 +83,13 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { this.messageDeletionScheduler = Schedulers.fromExecutor(messageDeletionExecutor); } - public void store(final List messages, final UUID destinationAccountUuid, final long destinationDeviceId) { + public void store(final List messages, final UUID destinationAccountUuid, + final byte destinationDeviceId) { storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId))); } - private void storeBatch(final List messages, final UUID destinationAccountUuid, final long destinationDeviceId) { + private void storeBatch(final List messages, final UUID destinationAccountUuid, + final byte destinationDeviceId) { if (messages.size() > DYNAMO_DB_MAX_BATCH_SIZE) { throw new IllegalArgumentException("Maximum batch size of " + DYNAMO_DB_MAX_BATCH_SIZE + " exceeded with " + messages.size() + " messages"); } @@ -112,7 +114,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems)); } - public Publisher load(final UUID destinationAccountUuid, final long destinationDeviceId, + public Publisher load(final UUID destinationAccountUuid, final byte destinationDeviceId, final Integer limit) { final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); @@ -191,7 +193,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { } public CompletableFuture> deleteMessage(final UUID destinationAccountUuid, - final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { + final byte destinationDeviceId, final UUID messageUuid, final long serverTimestamp) { final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid); @@ -240,7 +242,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { .toFuture(); } - public CompletableFuture deleteAllMessagesForDevice(final UUID destinationAccountUuid, final long destinationDeviceId) { + public CompletableFuture deleteAllMessagesForDevice(final UUID destinationAccountUuid, + final byte destinationDeviceId) { final Timer.Sample sample = Timer.start(); final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid); @@ -284,8 +287,10 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { return AttributeValues.fromUUID(destinationAccountUuid); } - private static AttributeValue convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) { + private static AttributeValue convertSortKey(final byte destinationDeviceId, final long serverTimestamp, + final UUID messageUuid) { ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]); + // for compatibility - destinationDeviceId was previously `long` byteBuffer.putLong(destinationDeviceId); byteBuffer.putLong(serverTimestamp); byteBuffer.putLong(messageUuid.getMostSignificantBits()); @@ -293,8 +298,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore { return AttributeValues.fromByteBuffer(byteBuffer.flip()); } - private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final long destinationDeviceId) { + private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final byte destinationDeviceId) { ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); + // for compatibility - destinationDeviceId was previously `long` byteBuffer.putLong(destinationDeviceId); return AttributeValues.fromByteBuffer(byteBuffer.flip()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 27b13b868..e0f0476c3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -60,7 +60,7 @@ public class MessagesManager { this.messageDeletionExecutor = messageDeletionExecutor; } - public void insert(UUID destinationUuid, long destinationDevice, Envelope message) { + public void insert(UUID destinationUuid, byte destinationDevice, Envelope message) { final UUID messageGuid = UUID.randomUUID(); messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message); @@ -70,11 +70,11 @@ public class MessagesManager { } } - public boolean hasCachedMessages(final UUID destinationUuid, final long destinationDevice) { + public boolean hasCachedMessages(final UUID destinationUuid, final byte destinationDevice) { return messagesCache.hasMessages(destinationUuid, destinationDevice); } - public Mono, Boolean>> getMessagesForDevice(UUID destinationUuid, long destinationDevice, + public Mono, Boolean>> getMessagesForDevice(UUID destinationUuid, byte destinationDevice, boolean cachedMessagesOnly) { return Flux.from( @@ -84,13 +84,13 @@ public class MessagesManager { .map(envelopes -> new Pair<>(envelopes, envelopes.size() >= RESULT_SET_CHUNK_SIZE)); } - public Publisher getMessagesForDeviceReactive(UUID destinationUuid, long destinationDevice, + public Publisher getMessagesForDeviceReactive(UUID destinationUuid, byte destinationDevice, final boolean cachedMessagesOnly) { return getMessagesForDevice(destinationUuid, destinationDevice, null, cachedMessagesOnly); } - private Publisher getMessagesForDevice(UUID destinationUuid, long destinationDevice, + private Publisher getMessagesForDevice(UUID destinationUuid, byte destinationDevice, @Nullable Integer limit, final boolean cachedMessagesOnly) { final Publisher dynamoPublisher = @@ -108,13 +108,13 @@ public class MessagesManager { messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid)); } - public CompletableFuture clear(UUID destinationUuid, long deviceId) { + public CompletableFuture clear(UUID destinationUuid, byte deviceId) { return CompletableFuture.allOf( messagesCache.clear(destinationUuid, deviceId), messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId)); } - public CompletableFuture> delete(UUID destinationUuid, long destinationDeviceId, UUID guid, + public CompletableFuture> delete(UUID destinationUuid, byte destinationDeviceId, UUID guid, @Nullable Long serverTimestamp) { return messagesCache.remove(destinationUuid, destinationDeviceId, guid) .thenComposeAsync(removed -> { @@ -140,7 +140,7 @@ public class MessagesManager { */ public int persistMessages( final UUID destinationUuid, - final long destinationDeviceId, + final byte destinationDeviceId, final List messages) { final List nonEphemeralMessages = messages.stream() @@ -165,7 +165,7 @@ public class MessagesManager { public void addMessageAvailabilityListener( final UUID destinationUuid, - final long destinationDeviceId, + final byte destinationDeviceId, final MessageAvailabilityListener listener) { messagesCache.addMessageAvailabilityListener(destinationUuid, destinationDeviceId, listener); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java index 1c12e1177..fa0556e42 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java @@ -14,7 +14,7 @@ public class RefreshingAccountAndDeviceSupplier implements Supplier new RefreshingAccountAndDeviceNotFoundException("Could not find device")); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java index daa870b8c..9b9fb8253 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java @@ -31,7 +31,7 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore } @Override - protected Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final ECSignedPreKey signedPreKey) { + protected Map getItemFromPreKey(final UUID accountUuid, final byte deviceId, final ECSignedPreKey signedPreKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), @@ -54,7 +54,7 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore } } - public CompletableFuture storeIfAbsent(final UUID identifier, final long deviceId, final ECSignedPreKey signedPreKey) { + public CompletableFuture storeIfAbsent(final UUID identifier, final byte deviceId, final ECSignedPreKey signedPreKey) { return dynamoDbAsyncClient.putItem(PutItemRequest.builder() .tableName(tableName) .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java index e6720a213..335ffe788 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStore.java @@ -21,7 +21,7 @@ public class RepeatedUseKEMSignedPreKeyStore extends RepeatedUseSignedPreKeyStor } @Override - protected Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final KEMSignedPreKey signedPreKey) { + protected Map getItemFromPreKey(final UUID accountUuid, final byte deviceId, final KEMSignedPreKey signedPreKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index 471aaf9e5..f1ec9b961 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -67,7 +67,7 @@ public abstract class RepeatedUseSignedPreKeyStore> { * * @return a future that completes once the key has been stored */ - public CompletableFuture store(final UUID identifier, final long deviceId, final K signedPreKey) { + public CompletableFuture store(final UUID identifier, final byte deviceId, final K signedPreKey) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.putItem(PutItemRequest.builder() @@ -87,13 +87,13 @@ public abstract class RepeatedUseSignedPreKeyStore> { * * @return a future that completes once all keys have been stored */ - public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { + public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder() .transactItems(signedPreKeysByDeviceId.entrySet().stream() .map(entry -> { - final long deviceId = entry.getKey(); + final byte deviceId = entry.getKey(); final K signedPreKey = entry.getValue(); return TransactWriteItem.builder() @@ -117,7 +117,7 @@ public abstract class RepeatedUseSignedPreKeyStore> { * @return a future that yields an optional signed pre-key if one is available for the target device or empty if no * key could be found for the target device */ - public CompletableFuture> find(final UUID identifier, final long deviceId) { + public CompletableFuture> find(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); final CompletableFuture> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder() @@ -165,7 +165,7 @@ public abstract class RepeatedUseSignedPreKeyStore> { * * @return a future that completes once the repeated-use pre-key has been removed from the target device */ - public CompletableFuture delete(final UUID identifier, final long deviceId) { + public CompletableFuture delete(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder() @@ -175,7 +175,7 @@ public abstract class RepeatedUseSignedPreKeyStore> { .thenRun(() -> sample.stop(deleteForDeviceTimer)); } - public Flux getDeviceIdsWithKeys(final UUID identifier) { + public Flux getDeviceIdsWithKeys(final UUID identifier) { return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) .keyConditionExpression("#uuid = :uuid") @@ -186,10 +186,10 @@ public abstract class RepeatedUseSignedPreKeyStore> { .consistentRead(true) .build()) .items()) - .map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n())); + .map(item -> Byte.parseByte(item.get(KEY_DEVICE_ID).n())); } - protected static Map getPrimaryKey(final UUID identifier, final long deviceId) { + protected static Map getPrimaryKey(final UUID identifier, final byte deviceId) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID, getSortKey(deviceId)); @@ -199,11 +199,12 @@ public abstract class RepeatedUseSignedPreKeyStore> { return AttributeValues.fromUUID(accountUuid); } - protected static AttributeValue getSortKey(final long deviceId) { - return AttributeValues.fromLong(deviceId); + protected static AttributeValue getSortKey(final byte deviceId) { + return AttributeValues.fromInt(deviceId); } - protected abstract Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final K signedPreKey); + protected abstract Map getItemFromPreKey(final UUID accountUuid, final byte deviceId, + final K signedPreKey); protected abstract K getPreKeyFromItem(final Map item); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java index 025f10b81..536884179 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -24,7 +24,7 @@ public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { } @Override - protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final ECPreKey preKey) { + protected Map getItemFromPreKey(final UUID identifier, final byte deviceId, final ECPreKey preKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java index 2e54fad37..b373d0e57 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java @@ -21,7 +21,7 @@ public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore getItemFromPreKey(final UUID identifier, final long deviceId, final KEMSignedPreKey signedPreKey) { + protected Map getItemFromPreKey(final UUID identifier, final byte deviceId, final KEMSignedPreKey signedPreKey) { return Map.of( KEY_ACCOUNT_UUID, getPartitionKey(identifier), KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java index 95e086544..305f80147 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -36,11 +36,11 @@ import software.amazon.awssdk.services.dynamodb.model.Select; /** * A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key - * store's {@link #take(UUID, long)} method are guaranteed to be returned exactly once, and repeated calls will never + * store's {@link #take(UUID, byte)} method are guaranteed to be returned exactly once, and repeated calls will never * yield the same key. *

* Each {@link Account} may have one or more {@link Device devices}. Clients should regularly check their - * supply of single-use pre-keys (see {@link #getCount(UUID, long)}) and upload new keys when their supply runs low. In + * supply of single-use pre-keys (see {@link #getCount(UUID, byte)}) and upload new keys when their supply runs low. In * the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party * may fall back to using the device's repeated-use ("last-resort") signed pre-key instead. */ @@ -91,7 +91,7 @@ public abstract class SingleUsePreKeyStore> { * @return a future that completes when all previously-stored keys have been removed and the given collection of * pre-keys has been stored in its place */ - public CompletableFuture store(final UUID identifier, final long deviceId, final List preKeys) { + public CompletableFuture store(final UUID identifier, final byte deviceId, final List preKeys) { final Timer.Sample sample = Timer.start(); return Mono.fromFuture(() -> delete(identifier, deviceId)) @@ -103,7 +103,7 @@ public abstract class SingleUsePreKeyStore> { .thenRun(() -> sample.stop(storeKeyBatchTimer)); } - private CompletableFuture store(final UUID identifier, final long deviceId, final K preKey) { + private CompletableFuture store(final UUID identifier, final byte deviceId, final K preKey) { final Timer.Sample sample = Timer.start(); return dynamoDbAsyncClient.putItem(PutItemRequest.builder() @@ -124,7 +124,7 @@ public abstract class SingleUsePreKeyStore> { * @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are * available for the target device */ - public CompletableFuture> take(final UUID identifier, final long deviceId) { + public CompletableFuture> take(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); final AttributeValue partitionKey = getPartitionKey(identifier); final AtomicInteger keysConsidered = new AtomicInteger(0); @@ -169,7 +169,7 @@ public abstract class SingleUsePreKeyStore> { * @return a future that yields the approximate number of single-use pre-keys currently available for the target * device */ - public CompletableFuture getCount(final UUID identifier, final long deviceId) { + public CompletableFuture getCount(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); // Getting an accurate count from DynamoDB can be very confusing. See: @@ -230,7 +230,7 @@ public abstract class SingleUsePreKeyStore> { * @return a future that completes when all single-use pre-keys have been removed for the target device */ - public CompletableFuture delete(final UUID identifier, final long deviceId) { + public CompletableFuture delete(final UUID identifier, final byte deviceId) { final Timer.Sample sample = Timer.start(); return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() @@ -267,20 +267,20 @@ public abstract class SingleUsePreKeyStore> { return AttributeValues.fromUUID(accountUuid); } - protected static AttributeValue getSortKey(final long deviceId, final long keyId) { + protected static AttributeValue getSortKey(final byte deviceId, final long keyId) { final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); byteBuffer.putLong(deviceId); byteBuffer.putLong(keyId); return AttributeValues.fromByteBuffer(byteBuffer.flip()); } - private static AttributeValue getSortKeyPrefix(final long deviceId) { + private static AttributeValue getSortKeyPrefix(final byte deviceId) { final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); byteBuffer.putLong(deviceId); return AttributeValues.fromByteBuffer(byteBuffer.flip()); } - protected abstract Map getItemFromPreKey(final UUID identifier, final long deviceId, + protected abstract Map getItemFromPreKey(final UUID identifier, final byte deviceId, final K preKey); protected abstract K getPreKeyFromItem(final Map item); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java index 1c26c731e..38c8c43c4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidator.java @@ -8,7 +8,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; -import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -24,7 +23,7 @@ public class DestinationDeviceValidator { * @see #validateRegistrationIds(Account, Stream, boolean) */ public static void validateRegistrationIds(final Account account, final Collection messages, - Function getDeviceId, Function getRegistrationId, boolean usePhoneNumberIdentity) + Function getDeviceId, Function getRegistrationId, boolean usePhoneNumberIdentity) throws StaleDevicesException { validateRegistrationIds(account, messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))), @@ -47,13 +46,13 @@ public class DestinationDeviceValidator { * account does not have a corresponding device or if the registration IDs do not match */ public static void validateRegistrationIds(final Account account, - final Stream> deviceIdAndRegistrationIdStream, + final Stream> deviceIdAndRegistrationIdStream, final boolean usePhoneNumberIdentity) throws StaleDevicesException { - final List staleDevices = deviceIdAndRegistrationIdStream + final List staleDevices = deviceIdAndRegistrationIdStream .filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0) .filter(deviceIdAndRegistrationId -> { - final long deviceId = deviceIdAndRegistrationId.first(); + final byte deviceId = deviceIdAndRegistrationId.first(); final int registrationId = deviceIdAndRegistrationId.second(); boolean registrationIdMatches = account.getDevice(deviceId) .map(device -> registrationId == (usePhoneNumberIdentity @@ -86,19 +85,19 @@ public class DestinationDeviceValidator { * account */ public static void validateCompleteDeviceList(final Account account, - final Set messageDeviceIds, - final Set excludedDeviceIds) throws MismatchedDevicesException { + final Set messageDeviceIds, + final Set excludedDeviceIds) throws MismatchedDevicesException { - final Set accountDeviceIds = account.getDevices().stream() + final Set accountDeviceIds = account.getDevices().stream() .filter(Device::isEnabled) .map(Device::getId) .filter(deviceId -> !excludedDeviceIds.contains(deviceId)) .collect(Collectors.toSet()); - final Set missingDeviceIds = new HashSet<>(accountDeviceIds); + final Set missingDeviceIds = new HashSet<>(accountDeviceIds); missingDeviceIds.removeAll(messageDeviceIds); - final Set extraDeviceIds = new HashSet<>(messageDeviceIds); + final Set extraDeviceIds = new HashSet<>(messageDeviceIds); extraDeviceIds.removeAll(accountDeviceIds); if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java index 38008bf76..49d477a16 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningAddress.java @@ -10,7 +10,7 @@ import java.util.Base64; public class ProvisioningAddress extends WebsocketAddress { - public ProvisioningAddress(String address, int id) { + public ProvisioningAddress(String address, byte id) { super(address, id); } @@ -26,6 +26,6 @@ public class ProvisioningAddress extends WebsocketAddress { byte[] random = new byte[16]; new SecureRandom().nextBytes(random); - return new ProvisioningAddress(Base64.getUrlEncoder().withoutPadding().encodeToString(random), 0); + return new ProvisioningAddress(Base64.getUrlEncoder().withoutPadding().encodeToString(random), (byte) 0); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java index df670064f..054479865 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java @@ -10,9 +10,9 @@ import org.whispersystems.textsecuregcm.storage.PubSubAddress; public class WebsocketAddress implements PubSubAddress { private final String number; - private final long deviceId; + private final byte deviceId; - public WebsocketAddress(String number, long deviceId) { + public WebsocketAddress(String number, byte deviceId) { this.number = number; this.deviceId = deviceId; } @@ -26,7 +26,7 @@ public class WebsocketAddress implements PubSubAddress { } this.number = parts[0]; - this.deviceId = Long.parseLong(parts[1]); + this.deviceId = Byte.parseByte(parts[1]); } catch (NumberFormatException e) { throw new InvalidWebsocketAddressException(e); } @@ -36,7 +36,7 @@ public class WebsocketAddress implements PubSubAddress { return number; } - public long getDeviceId() { + public byte getDeviceId() { return deviceId; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java index 770e23b05..5065ca0fc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java @@ -41,7 +41,7 @@ public class MigrateSignedECPreKeysCommand extends AbstractSinglePassCrawlAccoun accounts.flatMap(account -> Flux.fromIterable(account.getDevices()) .flatMap(device -> { - final List> keys = new ArrayList<>(2); + final List> keys = new ArrayList<>(2); if (device.getSignedPreKey(IdentityType.ACI) != null) { keys.add(Tuples.of(account.getUuid(), device.getId(), device.getSignedPreKey(IdentityType.ACI))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java index 814001764..21a0f830d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java @@ -36,7 +36,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand deviceIds = namespace.getList("deviceIds"); + final List deviceIds = namespace.getList("deviceIds"); final CommandDependencies deps = CommandDependencies.build("unlink-device", environment, configuration); @@ -68,7 +68,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand a.removeDevice(deviceId)); diff --git a/service/src/main/proto/org/signal/chat/device.proto b/service/src/main/proto/org/signal/chat/device.proto index 75de93101..eeb19fd6c 100644 --- a/service/src/main/proto/org/signal/chat/device.proto +++ b/service/src/main/proto/org/signal/chat/device.proto @@ -55,7 +55,7 @@ message GetDevicesResponse { /** * The identifier for the device within an account. */ - uint64 id = 1; + uint32 id = 1; /** * A sequence of bytes that encodes an encrypted human-readable name for @@ -86,7 +86,7 @@ message RemoveDeviceRequest { /** * The identifier for the device to remove from the authenticated account. */ - uint64 id = 1; + uint32 id = 1; } message SetDeviceNameRequest { diff --git a/service/src/main/proto/org/signal/chat/keys.proto b/service/src/main/proto/org/signal/chat/keys.proto index 317f64765..d900caf86 100644 --- a/service/src/main/proto/org/signal/chat/keys.proto +++ b/service/src/main/proto/org/signal/chat/keys.proto @@ -154,7 +154,7 @@ message GetPreKeysRequest { * retrieve pre-keys. If not set, pre-keys are returned for all devices * associated with the targeted account. */ - uint64 device_id = 2; + uint32 device_id = 2; } message GetPreKeysAnonymousRequest { @@ -199,7 +199,7 @@ message GetPreKeysResponse { /** * A map of device IDs to pre-key "bundles" for the targeted account. */ - map pre_keys = 2; + map pre_keys = 2; } message SetOneTimeEcPreKeysRequest { @@ -276,4 +276,3 @@ message CheckIdentityKeyResponse { */ bytes identity_key = 2; } - diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 3c8be9118..3c415416b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -43,7 +43,7 @@ import java.util.Set; import java.util.UUID; import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.LongStream; +import java.util.stream.IntStream; import java.util.stream.Stream; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -89,7 +89,7 @@ class AuthEnablementRefreshRequirementProviderTest { private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class); private final Account account = new Account(); - private final Device authenticatedDevice = DevicesHelper.createDevice(1L); + private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID); private final Supplier> principalSupplier = () -> Optional.of( new TestPrincipal("test", account, authenticatedDevice)); @@ -126,7 +126,8 @@ class AuthEnablementRefreshRequirementProviderTest { final UUID uuid = UUID.randomUUID(); account.setUuid(uuid); account.addDevice(authenticatedDevice); - LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId))); + IntStream.range(2, 4) + .forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId))); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); @@ -137,22 +138,22 @@ class AuthEnablementRefreshRequirementProviderTest { @Test void testBuildDevicesEnabled() { - final long disabledDeviceId = 3L; + final byte disabledDeviceId = 3; final Account account = mock(Account.class); final List devices = new ArrayList<>(); when(account.getDevices()).thenReturn(devices); - LongStream.range(1, 5) + IntStream.range(1, 5) .forEach(id -> { final Device device = mock(Device.class); - when(device.getId()).thenReturn(id); + when(device.getId()).thenReturn((byte) id); when(device.isEnabled()).thenReturn(id != disabledDeviceId); devices.add(device); }); - final Map devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account); + final Map devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account); assertEquals(4, devicesEnabled.size()); @@ -168,7 +169,7 @@ class AuthEnablementRefreshRequirementProviderTest { @ParameterizedTest @MethodSource - void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) { + void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) { assert initialEnabled.size() == finalEnabled.size(); assert account.getPrimaryDevice().orElseThrow().isEnabled(); @@ -199,13 +200,16 @@ class AuthEnablementRefreshRequirementProviderTest { } static Stream testDeviceEnabledChanged() { + final byte deviceId1 = Device.PRIMARY_ID; + final byte deviceId2 = 2; + final byte deviceId3 = 3; return Stream.of( - Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)), - Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)), - Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)), - Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)), - Arguments.of(Map.of(2L, false, 3L, true), Map.of(2L, true, 3L, true)), - Arguments.of(Map.of(2L, true, 3L, false), Map.of(2L, true, 3L, true)) + Arguments.of(Map.of(deviceId1, false, deviceId2, false), Map.of(deviceId1, true, deviceId2, false)), + Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)), + Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)), + Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)), + Arguments.of(Map.of(deviceId2, false, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)), + Arguments.of(Map.of(deviceId2, true, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)) ); } @@ -227,9 +231,9 @@ class AuthEnablementRefreshRequirementProviderTest { assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size()); - verify(clientPresenceManager).disconnectPresence(account.getUuid(), 1); - verify(clientPresenceManager).disconnectPresence(account.getUuid(), 2); - verify(clientPresenceManager).disconnectPresence(account.getUuid(), 3); + verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1); + verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2); + verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3); } @ParameterizedTest @@ -237,13 +241,13 @@ class AuthEnablementRefreshRequirementProviderTest { void testDeviceRemoved(final int removedDeviceCount) { assert account.getPrimaryDevice().orElseThrow().isEnabled(); - final List initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList()); + final List initialDeviceIds = account.getDevices().stream().map(Device::getId).toList(); - final List deletedDeviceIds = account.getDevices().stream() + final List deletedDeviceIds = account.getDevices().stream() .map(Device::getId) - .filter(deviceId -> deviceId != 1L) + .filter(deviceId -> deviceId != Device.PRIMARY_ID) .limit(removedDeviceCount) - .collect(Collectors.toList()); + .toList(); assert deletedDeviceIds.size() == removedDeviceCount; @@ -269,9 +273,9 @@ class AuthEnablementRefreshRequirementProviderTest { void testPrimaryDeviceDisabledAndDeviceRemoved() { assert account.getPrimaryDevice().orElseThrow().isEnabled(); - final Set initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()); + final Set initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()); - final long deletedDeviceId = 2L; + final byte deletedDeviceId = 2; assertTrue(initialDeviceIds.remove(deletedDeviceId)); final Response response = resources.getJerseyTest() @@ -427,11 +431,11 @@ class AuthEnablementRefreshRequirementProviderTest { @POST @Path("/account/devices/enabled") @ChangesDeviceEnabledState - public String setEnabled(@Auth TestPrincipal principal, Map deviceIdsEnabled) { + public String setEnabled(@Auth TestPrincipal principal, Map deviceIdsEnabled) { final StringBuilder response = new StringBuilder(); - for (Entry deviceIdEnabled : deviceIdsEnabled.entrySet()) { + for (Entry deviceIdEnabled : deviceIdsEnabled.entrySet()) { final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow(); DevicesHelper.setEnabled(device, deviceIdEnabled.getValue()); @@ -462,7 +466,7 @@ class AuthEnablementRefreshRequirementProviderTest { public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) { Arrays.stream(deviceIds.split(",")) - .map(Long::valueOf) + .map(Byte::valueOf) .forEach(auth.getAccount()::removeDevice); return "Removed device(s) " + deviceIds; @@ -471,7 +475,7 @@ class AuthEnablementRefreshRequirementProviderTest { @POST @Path("/account/disablePrimaryDeviceAndDeleteDevice/{deviceId}") @ChangesDeviceEnabledState - public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) { + public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") byte deviceId) { DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice().orElseThrow(), false); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticatorTest.java index 431296573..dcaca875c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticatorTest.java @@ -150,7 +150,7 @@ class BaseAccountAuthenticatorTest { @Test void testAuthenticate() { final UUID uuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final String password = "12345"; final Account account = mock(Account.class); @@ -180,7 +180,7 @@ class BaseAccountAuthenticatorTest { @Test void testAuthenticateNonDefaultDevice() { final UUID uuid = UUID.randomUUID(); - final long deviceId = 2; + final byte deviceId = 2; final String password = "12345"; final Account account = mock(Account.class); @@ -214,7 +214,7 @@ class BaseAccountAuthenticatorTest { @CartesianTest.Values(booleans = {true, false}) final boolean deviceEnabled, @CartesianTest.Values(booleans = {true, false}) final boolean authenticatedDeviceIsPrimary) { final UUID uuid = UUID.randomUUID(); - final long deviceId = authenticatedDeviceIsPrimary ? 1 : 2; + final byte deviceId = (byte) (authenticatedDeviceIsPrimary ? 1 : 2); final String password = "12345"; final Account account = mock(Account.class); @@ -253,7 +253,7 @@ class BaseAccountAuthenticatorTest { @Test void testAuthenticateV1() { final UUID uuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final String password = "12345"; final Account account = mock(Account.class); @@ -290,7 +290,7 @@ class BaseAccountAuthenticatorTest { @Test void testAuthenticateDeviceNotFound() { final UUID uuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final String password = "12345"; final Account account = mock(Account.class); @@ -312,13 +312,13 @@ class BaseAccountAuthenticatorTest { baseAccountAuthenticator.authenticate(new BasicCredentials(uuid + "." + (deviceId + 1), password), true); assertThat(maybeAuthenticatedAccount).isEmpty(); - verify(account).getDevice(deviceId + 1); + verify(account).getDevice((byte) (deviceId + 1)); } @Test void testAuthenticateIncorrectPassword() { final UUID uuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final String password = "12345"; final Account account = mock(Account.class); @@ -365,8 +365,9 @@ class BaseAccountAuthenticatorTest { @ParameterizedTest @MethodSource - void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier, final long expectedDeviceId) { - final Pair identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username); + void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier, + final byte expectedDeviceId) { + final Pair identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username); assertEquals(expectedIdentifier, identifierAndDeviceId.first()); assertEquals(expectedDeviceId, identifierAndDeviceId.second()); @@ -376,7 +377,7 @@ class BaseAccountAuthenticatorTest { return Stream.of( Arguments.of("", "", Device.PRIMARY_ID), Arguments.of("test", "test", Device.PRIMARY_ID), - Arguments.of("test.7", "test", 7)); + Arguments.of("test.7", "test", (byte) 7)); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java index 3e40c715a..cff6e3c36 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java @@ -34,11 +34,11 @@ class CertificateGeneratorTest { final CertificateGenerator certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(SIGNING_CERTIFICATE), Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(account.getNumber()).thenReturn("+18005551234"); - when(device.getId()).thenReturn(4L); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getNumber()).thenReturn("+18005551234"); + when(device.getId()).thenReturn((byte) 4); - assertTrue(certificateGenerator.createFor(account, device, true).length > 0); - assertTrue(certificateGenerator.createFor(account, device, false).length > 0); + assertTrue(certificateGenerator.createFor(account, device, true).length > 0); + assertTrue(certificateGenerator.createFor(account, device, false).length > 0); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java index cc0afbcca..f3e2bc81f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/OptionalAccessTest.java @@ -32,7 +32,7 @@ class OptionalAccessTest { void testUnidentifiedMissingTargetDevice() { Account account = mock(Account.class); when(account.isEnabled()).thenReturn(true); - when(account.getDevice(eq(10))).thenReturn(Optional.empty()); + when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty()); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes())); try { @@ -46,7 +46,7 @@ class OptionalAccessTest { void testUnidentifiedBadTargetDevice() { Account account = mock(Account.class); when(account.isEnabled()).thenReturn(true); - when(account.getDevice(eq(10))).thenReturn(Optional.empty()); + when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty()); when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes())); try { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java index a4fd52df8..1db4e8c45 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/MockAuthenticationInterceptor.java @@ -18,9 +18,9 @@ import org.whispersystems.textsecuregcm.util.Pair; public class MockAuthenticationInterceptor implements ServerInterceptor { @Nullable - private Pair authenticatedDevice; + private Pair authenticatedDevice; - public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) { + public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) { authenticatedDevice = new Pair<>(accountIdentifier, deviceId); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index eb4b5f728..ba9ba63f3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -10,8 +10,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.Mockito.anyLong; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; @@ -299,7 +299,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(204); verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000")); - verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any()); } @Test @@ -328,7 +328,7 @@ class AccountControllerTest { verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first")); verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second")); - verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any()); } @Test @@ -344,7 +344,7 @@ class AccountControllerTest { verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first")); verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null); - verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any()); + verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any()); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index e6ff0d951..1f4c8b19c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -160,7 +160,7 @@ class AccountControllerV2Test { } when(updatedAccount.getDevices()).thenReturn(devices); - for (long i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); when(updatedAccount.getDevice(i)).thenReturn(d); } @@ -481,7 +481,7 @@ class AccountControllerV2Test { when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); when(updatedAccount.getDevices()).thenReturn(devices); - for (long i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); when(updatedAccount.getDevice(i)).thenReturn(d); } @@ -661,7 +661,7 @@ class AccountControllerV2Test { assertEquals(account.isUnrestrictedUnidentifiedAccess(), structuredResponse.data().account().allowSealedSenderFromAnyone()); - final Set deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()); + final Set deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()); // all devices should be present structuredResponse.data().devices().forEach(deviceDataReport -> { @@ -704,8 +704,8 @@ class AccountControllerV2Test { buildTestAccountForDataReport(UUID.randomUUID(), exampleNumber1, true, true, Collections.emptyList(), - List.of(new DeviceData(1, account1Device1LastSeen, account1Device1Created, null), - new DeviceData(2, account1Device2LastSeen, account1Device2Created, "OWP"))), + List.of(new DeviceData(Device.PRIMARY_ID, account1Device1LastSeen, account1Device1Created, null), + new DeviceData((byte) 2, account1Device2LastSeen, account1Device2Created, "OWP"))), String.format(""" # Account Phone number: %s @@ -730,7 +730,7 @@ class AccountControllerV2Test { buildTestAccountForDataReport(UUID.randomUUID(), account2PhoneNumber, false, true, List.of(new AccountBadge("badge_a", badgeAExpiration, true)), - List.of(new DeviceData(1, account2Device1LastSeen, account2Device1Created, "OWI"))), + List.of(new DeviceData(Device.PRIMARY_ID, account2Device1LastSeen, account2Device1Created, "OWI"))), String.format(""" # Account Phone number: %s @@ -756,7 +756,7 @@ class AccountControllerV2Test { List.of( new AccountBadge("badge_b", badgeBExpiration, true), new AccountBadge("badge_c", badgeCExpiration, false)), - List.of(new DeviceData(1, account3Device1LastSeen, account3Device1Created, "OWA"))), + List.of(new DeviceData(Device.PRIMARY_ID, account3Device1LastSeen, account3Device1Created, "OWA"))), String.format(""" # Account Phone number: %s @@ -825,7 +825,7 @@ class AccountControllerV2Test { return account; } - private record DeviceData(long id, Instant lastSeen, Instant created, @Nullable String userAgent) { + private record DeviceData(byte id, Instant lastSeen, Instant created, @Nullable String userAgent) { } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 63c379f35..8be6a67c1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -8,7 +8,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.eq; @@ -99,6 +99,8 @@ class DeviceControllerTest { private static Map deviceConfiguration = new HashMap<>(); private static TestClock testClock = TestClock.now(); + private static final byte NEXT_DEVICE_ID = 42; + private static DeviceController deviceController = new DeviceController( generateLinkDeviceSecret(), accountsManager, @@ -137,9 +139,9 @@ class DeviceControllerTest { when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter); - when(primaryDevice.getId()).thenReturn(1L); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(account.getNextDeviceId()).thenReturn(42L); + when(account.getNextDeviceId()).thenReturn(NEXT_DEVICE_ID); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); @@ -154,9 +156,9 @@ class DeviceControllerTest { AccountsHelper.setupMockUpdate(accountsManager); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null)); + when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); } @AfterEach @@ -199,9 +201,9 @@ class DeviceControllerTest { MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class); - assertThat(response.getDeviceId()).isEqualTo(42L); + assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID); - verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); + verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID)); verify(commands).set(anyString(), anyString(), any()); } @@ -315,7 +317,7 @@ class DeviceControllerTest { .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class); - assertThat(response.getDeviceId()).isEqualTo(42L); + assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID); final ArgumentCaptor deviceCaptor = ArgumentCaptor.forClass(Device.class); verify(account).addDevice(deviceCaptor.capture()); @@ -335,7 +337,7 @@ class DeviceControllerTest { expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()), () -> assertNull(device.getGcmId())); - verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); + verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID)); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get())); verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get())); verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); @@ -751,7 +753,7 @@ class DeviceControllerTest { // this is a static mock, so it might have previous invocations clearInvocations(AuthHelper.VALID_ACCOUNT); - final long deviceId = 2; + final byte deviceId = 2; final Response response = resources .getJerseyTest() @@ -785,10 +787,10 @@ class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(403); - verify(messagesManager, never()).clear(any(), anyLong()); + verify(messagesManager, never()).clear(any(), anyByte()); verify(accountsManager, never()).update(eq(AuthHelper.VALID_ACCOUNT), any()); - verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyLong()); - verify(keysManager, never()).delete(any(), anyLong()); + verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyByte()); + verify(keysManager, never()).delete(any(), anyByte()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index b36d443c3..bb9e38989 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.clearInvocations; @@ -84,6 +84,11 @@ class KeysControllerTest { private static final UUID NOT_EXISTS_UUID = UUID.randomUUID(); + private static final byte SAMPLE_DEVICE_ID = 1; + private static final byte SAMPLE_DEVICE_ID2 = 2; + private static final byte SAMPLE_DEVICE_ID3 = 3; + private static final byte SAMPLE_DEVICE_ID4 = 4; + private static final int SAMPLE_REGISTRATION_ID = 999; private static final int SAMPLE_REGISTRATION_ID2 = 1002; private static final int SAMPLE_REGISTRATION_ID4 = 1555; @@ -180,6 +185,11 @@ class KeysControllerTest { final List allDevices = List.of(sampleDevice, sampleDevice2, sampleDevice3, sampleDevice4); + final byte sampleDeviceId = 1; + final byte sampleDevice2Id = 2; + final byte sampleDevice3Id = 3; + final byte sampleDevice4Id = 4; + AccountsHelper.setupMockUpdate(accounts); when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID); @@ -199,18 +209,18 @@ class KeysControllerTest { when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2); when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3); when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null); - when(sampleDevice.getId()).thenReturn(1L); - when(sampleDevice2.getId()).thenReturn(2L); - when(sampleDevice3.getId()).thenReturn(3L); - when(sampleDevice4.getId()).thenReturn(4L); + when(sampleDevice.getId()).thenReturn(sampleDeviceId); + when(sampleDevice2.getId()).thenReturn(sampleDevice2Id); + when(sampleDevice3.getId()).thenReturn(sampleDevice3Id); + when(sampleDevice4.getId()).thenReturn(sampleDevice4Id); when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI); - when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice)); - when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2)); - when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3)); - when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4)); - when(existsAccount.getDevice(22L)).thenReturn(Optional.empty()); + when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice)); + when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2)); + when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3)); + when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4)); + when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty()); when(existsAccount.getDevices()).thenReturn(allDevices); when(existsAccount.isEnabled()).thenReturn(true); when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY); @@ -225,17 +235,21 @@ class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.store(any(), anyByte(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI))); - when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI))); + when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); + when(KEYS.takeEC(EXISTS_PNI, sampleDeviceId)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI))); + when(KEYS.takePQ(EXISTS_PNI, sampleDeviceId)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI))); - when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5)); - when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5)); + when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); + when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5)); when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY); @@ -267,8 +281,8 @@ class KeysControllerTest { assertThat(result.getCount()).isEqualTo(5); assertThat(result.getPqCount()).isEqualTo(5); - verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1); - verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1); + verify(KEYS).getEcCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getPqCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID); } @Test @@ -284,7 +298,7 @@ class KeysControllerTest { verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any()); - verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any()); + verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any()); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(Device.PRIMARY_ID, test)); } @@ -303,7 +317,7 @@ class KeysControllerTest { verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey)); verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any()); - verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any()); + verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any()); verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(Device.PRIMARY_ID, replacementKey)); } @@ -329,20 +343,20 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertThat(result.getDevice(1).getPqPreKey()).isNull(); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @Test void validSingleRequestPqTestNoPqKeysV2() { - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(Optional.empty())); PreKeyResponse result = resources.getJerseyTest() .target(String.format("/v2/keys/%s/1", EXISTS_UUID)) @@ -353,15 +367,15 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertThat(result.getDevice(1).getPqPreKey()).isNull(); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).takePQ(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -376,15 +390,15 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).takePQ(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -398,14 +412,14 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); - assertThat(result.getDevice(1).getPqPreKey()).isNull(); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_PNI, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); + verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -420,15 +434,15 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); - assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_PNI, 1); - verify(KEYS).takePQ(EXISTS_PNI, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); + verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -444,14 +458,14 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey()); - assertThat(result.getDevice(1).getPqPreKey()).isNull(); - assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull(); + assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_PNI, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1); + verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -481,14 +495,14 @@ class KeysControllerTest { assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); assertThat(result.getDevicesCount()).isEqualTo(1); - assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey()); - assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey()); - assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI), - result.getDevice(1).getSignedPreKey()); + assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey()); + assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()); + assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI), + result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey()); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).takePQ(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); verifyNoMoreInteractions(KEYS); } @@ -534,10 +548,14 @@ class KeysControllerTest { @Test void validMultiRequestTestV2() { - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2))); - when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); - when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) @@ -548,56 +566,62 @@ class KeysControllerTest { assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); - ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey(); - ECPreKey preKey = results.getDevice(1).getPreKey(); - long registrationId = results.getDevice(1).getRegistrationId(); - long deviceId = results.getDevice(1).getDeviceId(); + ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); + ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey(); + long registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId(); + byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId(); assertEquals(SAMPLE_KEY, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); - assertThat(deviceId).isEqualTo(1); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID); - signedPreKey = results.getDevice(2).getSignedPreKey(); - preKey = results.getDevice(2).getPreKey(); - registrationId = results.getDevice(2).getRegistrationId(); - deviceId = results.getDevice(2).getDeviceId(); + signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey(); + preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey(); + registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId(); + deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId(); assertEquals(SAMPLE_KEY2, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); - assertThat(deviceId).isEqualTo(2); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2); - signedPreKey = results.getDevice(4).getSignedPreKey(); - preKey = results.getDevice(4).getPreKey(); - registrationId = results.getDevice(4).getRegistrationId(); - deviceId = results.getDevice(4).getDeviceId(); + signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey(); + preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey(); + registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId(); + deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId(); assertEquals(SAMPLE_KEY4, preKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(signedPreKey).isNull(); - assertThat(deviceId).isEqualTo(4); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).takeEC(EXISTS_UUID, 2); - verify(KEYS).takeEC(EXISTS_UUID, 4); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); verifyNoMoreInteractions(KEYS); } @Test void validMultiRequestPqTestV2() { - when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.takeEC(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(KEYS.takePQ(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); - when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); - when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); - when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); - when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); - when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); - when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3))); + when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2))); + when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn( + CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3))); PreKeyResponse results = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) @@ -609,51 +633,51 @@ class KeysControllerTest { assertThat(results.getDevicesCount()).isEqualTo(3); assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI)); - ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey(); - ECPreKey preKey = results.getDevice(1).getPreKey(); - KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey(); - long registrationId = results.getDevice(1).getRegistrationId(); - long deviceId = results.getDevice(1).getDeviceId(); + ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey(); + ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey(); + KEMSignedPreKey pqPreKey = results.getDevice(SAMPLE_DEVICE_ID).getPqPreKey(); + int registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId(); + byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId(); assertEquals(SAMPLE_KEY, preKey); assertEquals(SAMPLE_PQ_KEY, pqPreKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID); assertEquals(SAMPLE_SIGNED_KEY, signedPreKey); - assertThat(deviceId).isEqualTo(1); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID); - signedPreKey = results.getDevice(2).getSignedPreKey(); - preKey = results.getDevice(2).getPreKey(); - pqPreKey = results.getDevice(2).getPqPreKey(); - registrationId = results.getDevice(2).getRegistrationId(); - deviceId = results.getDevice(2).getDeviceId(); + signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey(); + preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey(); + pqPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getPqPreKey(); + registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId(); + deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId(); assertThat(preKey).isNull(); assertEquals(SAMPLE_PQ_KEY2, pqPreKey); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2); assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey); - assertThat(deviceId).isEqualTo(2); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2); - signedPreKey = results.getDevice(4).getSignedPreKey(); - preKey = results.getDevice(4).getPreKey(); - pqPreKey = results.getDevice(4).getPqPreKey(); - registrationId = results.getDevice(4).getRegistrationId(); - deviceId = results.getDevice(4).getDeviceId(); + signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey(); + preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey(); + pqPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getPqPreKey(); + registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId(); + deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId(); assertEquals(SAMPLE_KEY4, preKey); assertThat(pqPreKey).isNull(); assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4); assertThat(signedPreKey).isNull(); - assertThat(deviceId).isEqualTo(4); + assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4); - verify(KEYS).takeEC(EXISTS_UUID, 1); - verify(KEYS).takePQ(EXISTS_UUID, 1); - verify(KEYS).takeEC(EXISTS_UUID, 2); - verify(KEYS).takePQ(EXISTS_UUID, 2); - verify(KEYS).takeEC(EXISTS_UUID, 4); - verify(KEYS).takePQ(EXISTS_UUID, 4); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2); - verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4); + verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2); + verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4); verifyNoMoreInteractions(KEYS); } @@ -719,7 +743,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull()); + verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), + eq(signedPreKey), isNull()); assertThat(listCaptor.getValue()).containsExactly(preKey); @@ -750,7 +775,8 @@ class KeysControllerTest { ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey)); + verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(), + eq(signedPreKey), eq(pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); @@ -852,7 +878,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull()); + verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), eq(signedPreKey), + isNull()); assertThat(listCaptor.getValue()).containsExactly(preKey); @@ -884,7 +911,8 @@ class KeysControllerTest { ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey)); + verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(), + eq(signedPreKey), eq(pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); @@ -928,7 +956,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull()); + verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), + eq(signedPreKey), isNull()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -953,7 +982,8 @@ class KeysControllerTest { resources.getJerseyTest() .target("/v2/keys") .request() - .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, 2L, AuthHelper.VALID_PASSWORD_3_LINKED)) + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, SAMPLE_DEVICE_ID2, + AuthHelper.VALID_PASSWORD_3_LINKED)) .put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(403); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 34890fc81..4884cbe9b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -135,15 +135,15 @@ class MessageControllerTest { private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111"; private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111"); private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111"); - private static final int SINGLE_DEVICE_ID1 = 1; + private static final byte SINGLE_DEVICE_ID1 = 1; private static final int SINGLE_DEVICE_REG_ID1 = 111; private static final String MULTI_DEVICE_RECIPIENT = "+14152222222"; private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222"); private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222"); - private static final int MULTI_DEVICE_ID1 = 1; - private static final int MULTI_DEVICE_ID2 = 2; - private static final int MULTI_DEVICE_ID3 = 3; + private static final byte MULTI_DEVICE_ID1 = 1; + private static final byte MULTI_DEVICE_ID2 = 2; + private static final byte MULTI_DEVICE_ID3 = 3; private static final int MULTI_DEVICE_REG_ID1 = 222; private static final int MULTI_DEVICE_REG_ID2 = 333; private static final int MULTI_DEVICE_REG_ID3 = 444; @@ -225,7 +225,8 @@ class MessageControllerTest { when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); } - private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) { + private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId, + final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) { final Device device = new Device(); device.setId(id); device.setRegistrationId(registrationId); @@ -526,13 +527,14 @@ class MessageControllerTest { final UUID updatedPniOne = UUID.randomUUID(); List envelopes = List.of( - generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2, + generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, (byte) 2, AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false), - generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2, + generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, + (byte) 2, AuthHelper.VALID_UUID, null, null, 0, true) ); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean())) .thenReturn(Mono.just(new Pair<>(envelopes, false))); final String userAgent = "Test-UA"; @@ -580,13 +582,13 @@ class MessageControllerTest { final long timestampTwo = 313388; final List messages = List.of( - generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), 2, + generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), (byte) 2, AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0), generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, - UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0) + UUID.randomUUID(), (byte) 2, AuthHelper.VALID_UUID, null, null, 0) ); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean())) .thenReturn(Mono.just(new Pair<>(messages, false))); Response response = @@ -606,24 +608,24 @@ class MessageControllerTest { UUID sourceUuid = UUID.randomUUID(); UUID uuid1 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid1, null)) .thenReturn( CompletableFuture.completedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE, - timestamp, sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)))); + timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0)))); UUID uuid2 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid2, null)) .thenReturn( CompletableFuture.completedFuture(Optional.of(generateEnvelope( uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, - System.currentTimeMillis(), sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0)))); + System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0)))); UUID uuid3 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid3, null)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); UUID uuid4 = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid4, null)) + when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid4, null)) .thenReturn(CompletableFuture.failedFuture(new RuntimeException("Oh No"))); Response response = resources.getJerseyTest() @@ -633,7 +635,7 @@ class MessageControllerTest { .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); - verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L), + verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1), eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp)); response = resources.getJerseyTest() @@ -879,7 +881,7 @@ class MessageControllerTest { .request() .header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES)) .put(Entity.entity(new IncomingMessageList( - List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true, + List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true, System.currentTimeMillis()), MediaType.APPLICATION_JSON_TYPE)); @@ -919,7 +921,7 @@ class MessageControllerTest { ); } - private static void writePayloadDeviceId(ByteBuffer bb, long deviceId) { + private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) { long x = deviceId; // write the device-id in the 7-bit varint format we use, least significant bytes first. do { @@ -1155,7 +1157,7 @@ class MessageControllerTest { if (known) { r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]); } else { - r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]); + r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]); } Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]); @@ -1250,7 +1252,7 @@ class MessageControllerTest { SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class)); assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier, - new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))), + new MismatchedDevices(Collections.emptyList(), List.of(MULTI_DEVICE_ID3)))), mismatchedDevices); } @@ -1298,7 +1300,8 @@ class MessageControllerTest { assertEquals(1, staleDevices.size()); assertEquals(serviceIdentifier, staleDevices.get(0).uuid()); - assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices())); + assertEquals(Set.of(MULTI_DEVICE_ID1, MULTI_DEVICE_ID2), + new HashSet<>(staleDevices.get(0).devices().staleDevices())); } private static Stream sendMultiRecipientMessageStaleDevices() { @@ -1380,12 +1383,12 @@ class MessageControllerTest { } private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, - int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { + byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) { return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false); } private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid, - int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) { + byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) { final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder() .setType(MessageProtos.Envelope.Type.forNumber(type)) @@ -1413,14 +1416,14 @@ class MessageControllerTest { private static Recipient genRecipient(Random rng) { UUID u1 = UUID.randomUUID(); // non-null - long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903 + byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127 int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 byte[] perKeyBytes = new byte[48]; // size=48, non-null rng.nextBytes(perKeyBytes); return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes); } - private static void roundTripVarint(long expected, byte [] bytes) throws Exception { + private static void roundTripVarint(byte expected, byte[] bytes) throws Exception { ByteBuffer bb = ByteBuffer.wrap(bytes); writePayloadDeviceId(bb, expected); InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position()); @@ -1434,15 +1437,17 @@ class MessageControllerTest { byte[] bytes = new byte[12]; // some static test cases - for (long i = 1L; i <= 10L; i++) { + for (byte i = 1; i <= 10; i++) { roundTripVarint(i, bytes); } - roundTripVarint(Long.MAX_VALUE, bytes); + roundTripVarint(Byte.MAX_VALUE, bytes); for (int i = 0; i < 1000; i++) { // we need to ensure positive device IDs - long start = rng.nextLong() & Long.MAX_VALUE; - if (start == 0L) start = 1L; + byte start = (byte) rng.nextInt(128); + if (start == 0L) { + start = 1; + } // run the test for this case roundTripVarint(start, bytes); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index 9448870d7..f8f445f10 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -75,12 +75,12 @@ class OutgoingMessageEntityTest { final Account account = new Account(); account.setUuid(UUID.randomUUID()); - IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA"); + IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA"); MessageProtos.Envelope baseEnvelope = message.toEnvelope( new AciServiceIdentifier(UUID.randomUUID()), account, - 123L, + (byte) 123, System.currentTimeMillis(), false, true, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java index e44780ec5..c4580385f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/AccountsGrpcServiceTest.java @@ -170,7 +170,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest { final Account account = invocation.getArgument(0); final Device device = account.getDevice(invocation.getArgument(1)).orElseThrow(); @@ -99,8 +99,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder() .setId(17) .build())); } @ParameterizedTest - @ValueSource(longs = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) - void setDeviceName(final long deviceId) { + @ValueSource(bytes = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) + void setDeviceName(final byte deviceId) { mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId); final Device device = mock(Device.class); @@ -212,7 +212,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest streamBuilder = Stream.builder(); - for (final long deviceId : new long[] { Device.PRIMARY_ID, Device.PRIMARY_ID + 1 }) { + for (final byte deviceId : new byte[]{Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) { streamBuilder.add(Arguments.of(deviceId, SetPushTokenRequest.newBuilder() .setApnsTokenRequest(SetPushTokenRequest.ApnsTokenRequest.newBuilder() @@ -284,7 +284,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest setPushTokenUnchanged() { @@ -323,7 +323,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest authenticatedServiceStub().setPushToken(request)); - verify(accountsManager, never()).updateDevice(any(), anyLong(), any()); + verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); } private static Stream setPushTokenIllegalArgument() { @@ -342,7 +342,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest { + when(accountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(invocation -> { final Account account = invocation.getArgument(0); - final long deviceId = invocation.getArgument(1); + final byte deviceId = invocation.getArgument(1); final Consumer deviceUpdater = invocation.getArgument(2); account.getDevice(deviceId).ifPresent(deviceUpdater); @@ -477,13 +477,16 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest serviceIdentifier.uuid().equals(identifier)))) .thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount))); - final Map ecOneTimePreKeys = new HashMap<>(); - final Map kemPreKeys = new HashMap<>(); - final Map ecSignedPreKeys = new HashMap<>(); + final Map ecOneTimePreKeys = new HashMap<>(); + final Map kemPreKeys = new HashMap<>(); + final Map ecSignedPreKeys = new HashMap<>(); - final Map devices = new HashMap<>(); + final Map devices = new HashMap<>(); - for (final long deviceId : List.of(1, 2)) { + final byte deviceId1 = 1; + final byte deviceId2 = 2; + + for (final byte deviceId : List.of(deviceId1, deviceId2)) { ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey())); kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair)); ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair)); @@ -518,18 +521,18 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2); assertEquals(1, pendingDestinations.size()); - final Optional> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated( + final Optional> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated( pendingDestinations.get(0)); assertTrue(maybeUuidAndDeviceId.isPresent()); assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first()); - assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second()); + assertEquals(DEVICE_ID, maybeUuidAndDeviceId.get().second()); assertTrue( apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty()); @@ -236,8 +235,6 @@ class ApnPushNotificationSchedulerTest { final AccountsManager accountsManager = mock(AccountsManager.class); - final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); - apnPushNotificationScheduler = new ApnPushNotificationScheduler(redisCluster, apnSender, accountsManager, dedicatedThreadCount); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java index 82e70b5cf..7155e96a8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ClientPresenceManagerTest.java @@ -76,7 +76,7 @@ class ClientPresenceManagerTest { @Test void testIsPresent() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId)); @@ -87,7 +87,7 @@ class ClientPresenceManagerTest { @Test void testIsLocallyPresent() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId)); @@ -100,7 +100,7 @@ class ClientPresenceManagerTest { @Test void testLocalDisplacement() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final AtomicInteger displacementCounter = new AtomicInteger(0); final DisplacedPresenceListener displacementListener = connectedElsewhere -> displacementCounter.incrementAndGet(); @@ -117,7 +117,7 @@ class ClientPresenceManagerTest { @Test void testRemoteDisplacement() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final CompletableFuture displaced = new CompletableFuture<>(); @@ -135,7 +135,7 @@ class ClientPresenceManagerTest { @Test void testRemoteDisplacementAfterTopologyChange() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final CompletableFuture displaced = new CompletableFuture<>(); @@ -157,7 +157,7 @@ class ClientPresenceManagerTest { @Test void testClearPresence() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId)); @@ -210,7 +210,7 @@ class ClientPresenceManagerTest { @Test void testInitialPresenceExpiration() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); @@ -225,7 +225,7 @@ class ClientPresenceManagerTest { @Test void testRenewPresence() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final String presenceKey = ClientPresenceManager.getPresenceKey(accountUuid, deviceId); @@ -252,7 +252,7 @@ class ClientPresenceManagerTest { @Test void testExpiredPresence() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); @@ -266,7 +266,7 @@ class ClientPresenceManagerTest { } private void addClientPresence(final String managerId) { - final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7); + final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), (byte) 7); REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> { connection.sync().set(clientPresenceKey, managerId); @@ -278,17 +278,17 @@ class ClientPresenceManagerTest { void testClearAllOnStop() { final int localAccounts = 10; final UUID[] localUuids = new UUID[localAccounts]; - final long[] localDeviceIds = new long[localAccounts]; + final byte[] localDeviceIds = new byte[localAccounts]; for (int i = 0; i < localAccounts; i++) { localUuids[i] = UUID.randomUUID(); - localDeviceIds[i] = i; + localDeviceIds[i] = (byte) i; clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP); } final UUID displacedAccountUuid = UUID.randomUUID(); - final long displacedAccountDeviceId = 7; + final byte displacedAccountDeviceId = 7; clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP); REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> connection.sync() @@ -299,7 +299,7 @@ class ClientPresenceManagerTest { for (int i = 0; i < localAccounts; i++) { localUuids[i] = UUID.randomUUID(); - localDeviceIds[i] = i; + localDeviceIds[i] = (byte) i; assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i])); } @@ -346,7 +346,7 @@ class ClientPresenceManagerTest { @Test void testSetPresentRemotely() { final UUID uuid1 = UUID.randomUUID(); - final long deviceId = 1L; + final byte deviceId = 1; final CompletableFuture displaced = new CompletableFuture<>(); final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null); @@ -360,7 +360,7 @@ class ClientPresenceManagerTest { @Test void testDisconnectPresenceLocally() { final UUID uuid1 = UUID.randomUUID(); - final long deviceId = 1L; + final byte deviceId = 1; final CompletableFuture displaced = new CompletableFuture<>(); final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null); @@ -374,7 +374,7 @@ class ClientPresenceManagerTest { @Test void testDisconnectPresenceRemotely() { final UUID uuid1 = UUID.randomUUID(); - final long deviceId = 1L; + final byte deviceId = 1; final CompletableFuture displaced = new CompletableFuture<>(); final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java index 87cb32832..cb7ca6d3b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/MessageSenderTest.java @@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -42,7 +42,7 @@ class MessageSenderTest { private MessageSender messageSender; private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final long DEVICE_ID = 1L; + private static final byte DEVICE_ID = 1; @BeforeEach void setUp() { @@ -73,7 +73,7 @@ class MessageSenderTest { ArgumentCaptor envelopeArgumentCaptor = ArgumentCaptor.forClass( MessageProtos.Envelope.class); - verify(messagesManager).insert(any(), anyLong(), envelopeArgumentCaptor.capture()); + verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture()); assertTrue(envelopeArgumentCaptor.getValue().getEphemeral()); @@ -87,7 +87,7 @@ class MessageSenderTest { messageSender.sendMessage(account, device, message, true); - verify(messagesManager, never()).insert(any(), anyLong(), any()); + verify(messagesManager, never()).insert(any(), anyByte(), any()); verifyNoInteractions(pushNotificationManager); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java index ceb4c9daf..d24c06985 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ProvisioningManagerTest.java @@ -1,6 +1,16 @@ package org.whispersystems.textsecuregcm.push; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.after; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + import com.google.protobuf.ByteString; +import java.time.Duration; +import java.util.Random; +import java.util.function.Consumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -11,17 +21,6 @@ import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; -import java.time.Duration; -import java.util.Random; -import java.util.function.Consumer; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.after; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.verify; - class ProvisioningManagerTest { private ProvisioningManager provisioningManager; @@ -44,7 +43,7 @@ class ProvisioningManagerTest { @Test void sendProvisioningMessage() { - final ProvisioningAddress address = new ProvisioningAddress("address", 0); + final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0); final byte[] content = new byte[16]; new Random().nextBytes(content); @@ -65,7 +64,7 @@ class ProvisioningManagerTest { @Test void removeListener() { - final ProvisioningAddress address = new ProvisioningAddress("address", 0); + final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0); final byte[] content = new byte[16]; new Random().nextBytes(content); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushLatencyManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushLatencyManagerTest.java index 3fc81e29f..879571663 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/PushLatencyManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/PushLatencyManagerTest.java @@ -35,7 +35,7 @@ class PushLatencyManagerTest { @MethodSource void testTakeRecord(final boolean isVoip, final boolean isUrgent) throws ExecutionException, InterruptedException { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final Instant pushTimestamp = Instant.now(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java index d6fcab4a1..faa1eb993 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountTest.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -85,15 +86,16 @@ class AccountTest { when(agingSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)); when(agingSecondaryDevice.isEnabled()).thenReturn(false); - when(agingSecondaryDevice.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(agingSecondaryDevice.getId()).thenReturn(deviceId2); when(recentSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(1)); when(recentSecondaryDevice.isEnabled()).thenReturn(true); - when(recentSecondaryDevice.getId()).thenReturn(2L); + when(recentSecondaryDevice.getId()).thenReturn(deviceId2); when(oldSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366)); when(oldSecondaryDevice.isEnabled()).thenReturn(false); - when(oldSecondaryDevice.getId()).thenReturn(2L); + when(oldSecondaryDevice.getId()).thenReturn(deviceId2); when(senderKeyCapableDevice.getCapabilities()).thenReturn( new DeviceCapabilities(true, true, false, false)); @@ -143,17 +145,17 @@ class AccountTest { new DeviceCapabilities(true, true, false, false)); when(pniIncapableExpiredDevice.isEnabled()).thenReturn(false); - when(storiesCapableDevice.getId()).thenReturn(1L); + when(storiesCapableDevice.getId()).thenReturn(Device.PRIMARY_ID); when(storiesCapableDevice.getCapabilities()).thenReturn( new DeviceCapabilities(true, true, false, false)); when(storiesCapableDevice.isEnabled()).thenReturn(true); - when(storiesCapableDevice.getId()).thenReturn(2L); + when(storiesCapableDevice.getId()).thenReturn(deviceId2); when(storiesIncapableDevice.getCapabilities()).thenReturn( new DeviceCapabilities(true, true, false, false)); when(storiesIncapableDevice.isEnabled()).thenReturn(true); - when(storiesCapableDevice.getId()).thenReturn(3L); + when(storiesCapableDevice.getId()).thenReturn((byte) 3); when(storiesIncapableExpiredDevice.getCapabilities()).thenReturn( new DeviceCapabilities(true, true, false, false)); when(storiesIncapableExpiredDevice.isEnabled()).thenReturn(false); @@ -192,10 +194,11 @@ class AccountTest { when(disabledPrimaryDevice.isEnabled()).thenReturn(false); when(disabledLinkedDevice.isEnabled()).thenReturn(false); - when(enabledPrimaryDevice.getId()).thenReturn(1L); - when(enabledLinkedDevice.getId()).thenReturn(2L); - when(disabledPrimaryDevice.getId()).thenReturn(1L); - when(disabledLinkedDevice.getId()).thenReturn(2L); + when(enabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + final byte deviceId2 = 2; + when(enabledLinkedDevice.getId()).thenReturn(deviceId2); + when(disabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + when(disabledLinkedDevice.getId()).thenReturn(deviceId2); assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice)).isEnabled()); assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice, enabledLinkedDevice)).isEnabled()); @@ -214,15 +217,15 @@ class AccountTest { final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class); final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class); - when(transferCapablePrimaryDevice.getId()).thenReturn(1L); + when(transferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(transferCapablePrimaryDevice.isPrimary()).thenReturn(true); when(transferCapablePrimaryDevice.getCapabilities()).thenReturn(transferCapabilities); - when(nonTransferCapablePrimaryDevice.getId()).thenReturn(1L); + when(nonTransferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(nonTransferCapablePrimaryDevice.isPrimary()).thenReturn(true); when(nonTransferCapablePrimaryDevice.getCapabilities()).thenReturn(nonTransferCapabilities); - when(transferCapableLinkedDevice.getId()).thenReturn(2L); + when(transferCapableLinkedDevice.getId()).thenReturn((byte) 2); when(transferCapableLinkedDevice.isPrimary()).thenReturn(false); when(transferCapableLinkedDevice.getCapabilities()).thenReturn(transferCapabilities); @@ -311,21 +314,31 @@ class AccountTest { final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), devices, new byte[0]); - assertThat(account.getNextDeviceId()).isEqualTo(2L); + final byte deviceId2 = 2; + assertThat(account.getNextDeviceId()).isEqualTo(deviceId2); - account.addDevice(createDevice(2L)); + account.addDevice(createDevice(deviceId2)); - assertThat(account.getNextDeviceId()).isEqualTo(3L); + final byte deviceId3 = 3; + assertThat(account.getNextDeviceId()).isEqualTo(deviceId3); - account.addDevice(createDevice(3L)); + account.addDevice(createDevice(deviceId3)); - setEnabled(account.getDevice(2L).orElseThrow(), false); + setEnabled(account.getDevice(deviceId2).orElseThrow(), false); - assertThat(account.getNextDeviceId()).isEqualTo(4L); + assertThat(account.getNextDeviceId()).isEqualTo((byte) 4); - account.removeDevice(2L); + account.removeDevice(deviceId2); - assertThat(account.getNextDeviceId()).isEqualTo(2L); + assertThat(account.getNextDeviceId()).isEqualTo(deviceId2); + + while (account.getNextDeviceId() < Device.MAXIMUM_DEVICE_ID) { + account.addDevice(createDevice(account.getNextDeviceId())); + } + + account.addDevice(createDevice(Device.MAXIMUM_DEVICE_ID)); + + assertThatThrownBy(account::getNextDeviceId).isInstanceOf(RuntimeException.class); } @Test @@ -399,7 +412,7 @@ class AccountTest { final Device disabledPrimary = mock(Device.class); when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID); - final long linked1DeviceId = Device.PRIMARY_ID + 1; + final byte linked1DeviceId = Device.PRIMARY_ID + 1; final Device enabledLinked1 = mock(Device.class); when(enabledLinked1.isEnabled()).thenReturn(true); when(enabledLinked1.getId()).thenReturn(linked1DeviceId); @@ -407,7 +420,7 @@ class AccountTest { final Device disabledLinked1 = mock(Device.class); when(disabledLinked1.getId()).thenReturn(linked1DeviceId); - final long linked2DeviceId = Device.PRIMARY_ID + 2; + final byte linked2DeviceId = Device.PRIMARY_ID + 2; final Device enabledLinked2 = mock(Device.class); when(enabledLinked2.isEnabled()).thenReturn(true); when(enabledLinked2.getId()).thenReturn(linked2DeviceId); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 8b879c4be..44f61aea7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -178,8 +178,8 @@ class AccountsManagerChangeNumberIntegrationTest { final UUID originalPni = account.getPhoneNumberIdentifier(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); - final Map registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); + final Map preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); + final Map registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 0f12d3315..f2063a4f3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -141,8 +141,8 @@ class AccountsManagerConcurrentModificationIntegrationTest { accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()), a -> { a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - a.removeDevice(1); - a.addDevice(DevicesHelper.createDevice(1)); + a.removeDevice(Device.PRIMARY_ID); + a.addDevice(DevicesHelper.createDevice(Device.PRIMARY_ID)); }); uuid = account.getUuid(); @@ -212,7 +212,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { }, mutationExecutor); } - private CompletableFuture modifyDevice(final UUID uuid, final long deviceId, final Consumer deviceMutation) { + private CompletableFuture modifyDevice(final UUID uuid, final byte deviceId, final Consumer deviceMutation) { return CompletableFuture.runAsync(() -> { final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 2a51fa5b6..25df81c65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -876,7 +876,7 @@ class AccountsManagerTest { enabledDevice.setFetchesMessages(true); enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair())); enabledDevice.setLastSeen(System.currentTimeMillis()); - final long deviceId = account.getNextDeviceId(); + final byte deviceId = account.getNextDeviceId(); enabledDevice.setId(deviceId); account.addDevice(enabledDevice); @@ -909,7 +909,7 @@ class AccountsManagerTest { enabledDevice.setFetchesMessages(true); enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair())); enabledDevice.setLastSeen(System.currentTimeMillis()); - final long deviceId = account.getNextDeviceId(); + final byte deviceId = account.getNextDeviceId(); enabledDevice.setId(deviceId); account.addDevice(enabledDevice); @@ -1064,7 +1064,8 @@ class AccountsManagerTest { final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); assertThrows(IllegalArgumentException.class, () -> accountsManager.changeNumber( - account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)), + account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), + Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of((byte) 1, 101)), "AccountsManager should not allow use of changeNumber with new PNI keys but without changing number"); verify(accounts, never()).update(any()); @@ -1107,24 +1108,26 @@ class AccountsManagerTest { final UUID uuid = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); + final byte deviceId2 = 2; + final byte deviceId3 = 3; final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), - 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - final Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), + deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); + final Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L, 3L))); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3))); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); final List devices = List.of( - DevicesHelper.createDevice(1L, 0L, 101), - DevicesHelper.createDevice(2L, 0L, 102), - DevicesHelper.createDisabledDevice(3L, 103)); + DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102), + DevicesHelper.createDisabledDevice(deviceId3, 103)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final Account updatedAccount = accountsManager.changeNumber( account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds); @@ -1140,7 +1143,8 @@ class AccountsManagerTest { verify(keysManager).delete(originalPni); verify(keysManager).getPqEnabledDevices(uuid); verify(keysManager).storeEcSignedPreKeys(newPni, newSignedKeys); - verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); + verify(keysManager).storePqLastResort(eq(newPni), + eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID)))); verifyNoMoreInteractions(keysManager); } @@ -1153,19 +1157,22 @@ class AccountsManagerTest { final UUID uuid = UUID.randomUUID(); final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); + final byte deviceId2 = 2; final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair)); - final Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)); + final Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn( + CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID))); - final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); assertThrows(MismatchedDevicesException.class, () -> accountsManager.changeNumber( @@ -1189,18 +1196,20 @@ class AccountsManagerTest { @Test void testPniUpdate() throws MismatchedDevicesException { final String number = "+14152222222"; + final byte deviceId2 = 2; - List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - Map oldSignedPreKeys = account.getDevices().stream() + Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -1217,7 +1226,7 @@ class AccountsManagerTest { assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)))); - assertEquals(Map.of(1L, 101, 2L, 102), + assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); // PNI stuff should @@ -1236,26 +1245,29 @@ class AccountsManagerTest { @Test void testPniPqUpdate() throws MismatchedDevicesException { final String number = "+14152222222"; + final byte deviceId2 = 2; - List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), - 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), + deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); + Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L))); + when(keysManager.getPqEnabledDevices(oldPni)).thenReturn( + CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID))); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - Map oldSignedPreKeys = account.getDevices().stream() + Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -1270,7 +1282,7 @@ class AccountsManagerTest { assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)))); - assertEquals(Map.of(1L, 101, 2L, 102), + assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); // PNI keys should @@ -1287,23 +1299,26 @@ class AccountsManagerTest { verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys); // only the pq key for the already-pq-enabled device should be saved - verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); + verify(keysManager).storePqLastResort(eq(oldPni), + eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID)))); } @Test void testPniNonPqToPqUpdate() throws MismatchedDevicesException { final String number = "+14152222222"; + final byte deviceId2 = 2; - List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), - 2L, KeysHelper.signedKEMPreKey(4, identityKeyPair)); - Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), + deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair)); + Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); @@ -1312,7 +1327,7 @@ class AccountsManagerTest { when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - Map oldSignedPreKeys = account.getDevices().stream() + Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -1327,7 +1342,7 @@ class AccountsManagerTest { assertNull(updatedAccount.getIdentityKey(IdentityType.ACI)); assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)))); - assertEquals(Map.of(1L, 101, 2L, 102), + assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102), updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId))); // PNI keys should @@ -1348,19 +1363,21 @@ class AccountsManagerTest { @Test void testPniUpdate_incompleteKeys() { final String number = "+14152222222"; - - List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final byte deviceId2 = 2; + final byte deviceId3 = 3; + List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 2L, KeysHelper.signedECPreKey(1, identityKeyPair), - 3L, KeysHelper.signedECPreKey(2, identityKeyPair)); - Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + deviceId2, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair)); + Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - Map oldSignedPreKeys = account.getDevices().stream() + Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); @@ -1375,21 +1392,22 @@ class AccountsManagerTest { @Test void testPniPqUpdate_incompleteKeys() { final String number = "+14152222222"; - - List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); + final byte deviceId2 = 2; + List devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101), + DevicesHelper.createDevice(deviceId2, 0L, 102)); Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - final Map newSignedKeys = Map.of( - 1L, KeysHelper.signedECPreKey(1, identityKeyPair), - 2L, KeysHelper.signedECPreKey(2, identityKeyPair)); - final Map newSignedPqKeys = Map.of( - 1L, KeysHelper.signedKEMPreKey(3, identityKeyPair)); - Map newRegistrationIds = Map.of(1L, 201, 2L, 202); + final Map newSignedKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair)); + final Map newSignedPqKeys = Map.of( + Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair)); + Map newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202); UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - Map oldSignedPreKeys = account.getDevices().stream() + Map oldSignedPreKeys = account.getDevices().stream() .collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 8b677ce9b..d9dd959d2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -75,6 +76,9 @@ import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class AccountsTest { + private static final byte DEVICE_ID_1 = 1; + private static final byte DEVICE_ID_2 = 2; + private static final String BASE_64_URL_USERNAME_HASH_1 = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE"; private static final String BASE_64_URL_USERNAME_HASH_2 = "NLUom-CHwtemcdvOTTXdmXmzRIV7F05leS8lwkVK_vc"; private static final String BASE_64_URL_ENCRYPTED_USERNAME_1 = "md1votbj9r794DsqTNrBqA"; @@ -156,7 +160,7 @@ class AccountsTest { @Test void testStore() { - Device device = generateDevice(1); + Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); boolean freshUser = accounts.create(account); @@ -179,7 +183,7 @@ class AccountsTest { void testStoreRecentlyDeleted() { final UUID originalUuid = UUID.randomUUID(); - Device device = generateDevice(1); + Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device)); boolean freshUser = accounts.create(account); @@ -205,7 +209,7 @@ class AccountsTest { @Test void testStoreMulti() { - final List devices = List.of(generateDevice(1), generateDevice(2)); + final List devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices); accounts.create(account); @@ -218,13 +222,13 @@ class AccountsTest { @Test void testRetrieve() { - final List devicesFirst = List.of(generateDevice(1), generateDevice(2)); + final List devicesFirst = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); UUID uuidFirst = UUID.randomUUID(); UUID pniFirst = UUID.randomUUID(); Account accountFirst = generateAccount("+14151112222", uuidFirst, pniFirst, devicesFirst); - final List devicesSecond = List.of(generateDevice(1), generateDevice(2)); + final List devicesSecond = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); UUID uuidSecond = UUID.randomUUID(); UUID pniSecond = UUID.randomUUID(); @@ -263,7 +267,7 @@ class AccountsTest { @Test void testRetrieveNoPni() throws JsonProcessingException { - final List devices = List.of(generateDevice(1), generateDevice(2)); + final List devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); final UUID uuid = UUID.randomUUID(); final Account account = generateAccount("+14151112222", uuid, null, devices); @@ -321,7 +325,7 @@ class AccountsTest { @Test void testOverwrite() { - Device device = generateDevice(1); + Device device = generateDevice(DEVICE_ID_1); UUID firstUuid = UUID.randomUUID(); UUID firstPni = UUID.randomUUID(); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device)); @@ -346,7 +350,7 @@ class AccountsTest { UUID secondUuid = UUID.randomUUID(); - device = generateDevice(1); + device = generateDevice(DEVICE_ID_1); account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device)); final boolean freshUser = accounts.create(account); @@ -356,7 +360,7 @@ class AccountsTest { assertPhoneNumberConstraintExists("+14151112222", firstUuid); assertPhoneNumberIdentifierConstraintExists(firstPni, firstUuid); - device = generateDevice(1); + device = generateDevice(DEVICE_ID_1); Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device)); assertThatThrownBy(() -> accounts.create(invalidAccount)); @@ -364,7 +368,7 @@ class AccountsTest { @Test void testUpdate() { - Device device = generateDevice (1 ); + Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); accounts.create(account); @@ -389,7 +393,7 @@ class AccountsTest { assertThat(retrieved.isPresent()).isTrue(); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); - device = generateDevice(1); + device = generateDevice(DEVICE_ID_1); Account unknownAccount = generateAccount("+14151113333", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); assertThatThrownBy(() -> accounts.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class); @@ -452,10 +456,10 @@ class AccountsTest { @Test void testDelete() { - final Device deletedDevice = generateDevice(1); + final Device deletedDevice = generateDevice(DEVICE_ID_1); final Account deletedAccount = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(deletedDevice)); - final Device retainedDevice = generateDevice(1); + final Device retainedDevice = generateDevice(DEVICE_ID_1); final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(), UUID.randomUUID(), List.of(retainedDevice)); @@ -485,7 +489,7 @@ class AccountsTest { { final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(), - UUID.randomUUID(), List.of(generateDevice(1))); + UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); final boolean freshUser = accounts.create(recreatedAccount); @@ -501,7 +505,7 @@ class AccountsTest { @Test void testMissing() { - Device device = generateDevice (1 ); + Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); accounts.create(account); @@ -518,7 +522,7 @@ class AccountsTest { assertThat(accounts.getByAccountIdentifierAsync(UUID.randomUUID()).join()).isEmpty(); final Account account = - generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1))); + generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); accounts.create(account); @@ -530,7 +534,7 @@ class AccountsTest { assertThat(accounts.getByPhoneNumberIdentifierAsync(UUID.randomUUID()).join()).isEmpty(); final Account account = - generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1))); + generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); accounts.create(account); @@ -544,7 +548,7 @@ class AccountsTest { assertThat(accounts.getByE164Async(e164).join()).isEmpty(); final Account account = - generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1))); + generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); accounts.create(account); @@ -553,7 +557,7 @@ class AccountsTest { @Test void testCanonicallyDiscoverableSet() { - Device device = generateDevice(1); + Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); account.setDiscoverableByPhoneNumber(false); accounts.create(account); @@ -576,7 +580,7 @@ class AccountsTest { final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); - final Device device = generateDevice(1); + final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); accounts.create(account); @@ -631,10 +635,10 @@ class AccountsTest { final UUID originalPni = UUID.randomUUID(); final UUID targetPni = UUID.randomUUID(); - final Device existingDevice = generateDevice(1); + final Device existingDevice = generateDevice(DEVICE_ID_1); final Account existingAccount = generateAccount(targetNumber, UUID.randomUUID(), targetPni, List.of(existingDevice)); - final Device device = generateDevice(1); + final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); accounts.create(account); @@ -653,7 +657,7 @@ class AccountsTest { final String originalNumber = "+14151112222"; final String targetNumber = "+14151113333"; - final Device device = generateDevice(1); + final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device)); accounts.create(account); @@ -969,7 +973,48 @@ class AccountsTest { assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isPresent(); } - private static Device generateDevice(long id) { + @Test + public void testInvalidDeviceIdDeserialization() throws Exception { + final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); + final Device device2 = generateDevice((byte) 64); + account.addDevice(device2); + + accounts.create(account); + + final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder() + .tableName(Tables.ACCOUNTS.tableName()) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) + .build()).join(); + + final Map accountData = SystemMapper.jsonMapper() + .readValue(response.item().get(Accounts.ATTR_ACCOUNT_DATA).b().asByteArray(), Map.class); + + final List> devices = (List>) accountData.get("devices"); + assertEquals(Integer.valueOf(device2.getId()), devices.get(1).get("id")); + + devices.get(1).put("id", Byte.MAX_VALUE + 5); + + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().updateItem(UpdateItemRequest.builder() + .tableName(Tables.ACCOUNTS.tableName()) + .key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid()))) + .updateExpression("SET #data = :data") + .expressionAttributeNames(Map.of("#data", Accounts.ATTR_ACCOUNT_DATA)) + .expressionAttributeValues( + Map.of(":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(accountData)))) + .build()).join(); + + final CompletionException e = assertThrows(CompletionException.class, + () -> accounts.getByAccountIdentifierAsync(account.getUuid()).join()); + + Throwable cause = e.getCause(); + while (cause.getCause() != null) { + cause = cause.getCause(); + } + + assertInstanceOf(DeviceIdDeserializer.DeviceIdDeserializationException.class, cause); + } + + private static Device generateDevice(byte id) { return DevicesHelper.createDevice(id); } @@ -979,7 +1024,7 @@ class AccountsTest { } private static Account generateAccount(String number, UUID uuid, final UUID pni) { - Device device = generateDevice(1); + Device device = generateDevice(DEVICE_ID_1); return generateAccount(number, uuid, pni, List.of(device)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java index 5463988e8..0d895efd3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ChangeNumberManagerTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -68,7 +69,7 @@ public class ChangeNumberManagerTest { when(updatedAccount.getNumber()).thenReturn(number); when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni); when(updatedAccount.getDevices()).thenReturn(devices); - for (long i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); when(updatedAccount.getDevice(i)).thenReturn(d); } @@ -87,7 +88,7 @@ public class ChangeNumberManagerTest { when(updatedAccount.getUuid()).thenReturn(uuid); when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni); when(updatedAccount.getDevices()).thenReturn(devices); - for (long i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Optional d = account.getDevice(i); when(updatedAccount.getDevice(i)).thenReturn(d); } @@ -102,7 +103,7 @@ public class ChangeNumberManagerTest { when(account.getNumber()).thenReturn("+18005551234"); changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null); verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null); - verify(accountsManager, never()).updateDevice(any(), eq(1L), any()); + verify(accountsManager, never()).updateDevice(any(), anyByte(), any()); verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false)); } @@ -112,7 +113,8 @@ public class ChangeNumberManagerTest { when(account.getNumber()).thenReturn("+18005551234"); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap()); verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap()); @@ -133,18 +135,21 @@ public class ChangeNumberManagerTest { final Device d2 = mock(Device.class); when(d2.isEnabled()).thenReturn(true); - when(d2.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(d2.getId()).thenReturn(deviceId2); - when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 19); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); @@ -177,19 +182,23 @@ public class ChangeNumberManagerTest { final Device d2 = mock(Device.class); when(d2.isEnabled()).thenReturn(true); - when(d2.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(d2.getId()).thenReturn(deviceId2); - when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 19); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), + (byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -220,19 +229,23 @@ public class ChangeNumberManagerTest { final Device d2 = mock(Device.class); when(d2.isEnabled()).thenReturn(true); - when(d2.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(d2.getId()).thenReturn(deviceId2); - when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 19); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), + (byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -261,18 +274,21 @@ public class ChangeNumberManagerTest { final Device d2 = mock(Device.class); when(d2.isEnabled()).thenReturn(true); - when(d2.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(d2.getId()).thenReturn(deviceId2); - when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 19); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds); @@ -301,19 +317,23 @@ public class ChangeNumberManagerTest { final Device d2 = mock(Device.class); when(d2.isEnabled()).thenReturn(true); - when(d2.getId()).thenReturn(2L); + final byte deviceId2 = 2; + when(d2.getId()).thenReturn(deviceId2); - when(account.getDevice(2L)).thenReturn(Optional.of(d2)); + when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2)); when(account.getDevices()).thenReturn(List.of(d2)); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); - final Map prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); - final Map pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 19); + final Map prekeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair)); + final Map pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), + (byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19); final IncomingMessage msg = mock(IncomingMessage.class); - when(msg.destinationDeviceId()).thenReturn(2L); + when(msg.destinationDeviceId()).thenReturn(deviceId2); when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1})); changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds); @@ -338,11 +358,11 @@ public class ChangeNumberManagerTest { final List devices = new ArrayList<>(); - for (int i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Device device = mock(Device.class); - when(device.getId()).thenReturn((long) i); + when(device.getId()).thenReturn(i); when(device.isEnabled()).thenReturn(true); - when(device.getRegistrationId()).thenReturn(i); + when(device.getRegistrationId()).thenReturn((int) i); devices.add(device); when(account.getDevice(i)).thenReturn(Optional.of(device)); @@ -350,15 +370,21 @@ public class ChangeNumberManagerTest { when(account.getDevices()).thenReturn(devices); + final byte destinationDeviceId2 = 2; + final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, 2, 1, "foo"), - new IncomingMessage(1, 3, 1, "foo")); + new IncomingMessage(1, destinationDeviceId2, 1, "foo"), + new IncomingMessage(1, destinationDeviceId3, 1, "foo")); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); - final Map preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + final Map preKeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), + destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47, + destinationDeviceId3, 89); assertThrows(StaleDevicesException.class, () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds)); @@ -371,11 +397,11 @@ public class ChangeNumberManagerTest { final List devices = new ArrayList<>(); - for (int i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Device device = mock(Device.class); - when(device.getId()).thenReturn((long) i); + when(device.getId()).thenReturn(i); when(device.isEnabled()).thenReturn(true); - when(device.getRegistrationId()).thenReturn(i); + when(device.getRegistrationId()).thenReturn((int) i); devices.add(device); when(account.getDevice(i)).thenReturn(Optional.of(device)); @@ -383,15 +409,21 @@ public class ChangeNumberManagerTest { when(account.getDevices()).thenReturn(devices); + final byte destinationDeviceId2 = 2; + final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, 2, 1, "foo"), - new IncomingMessage(1, 3, 1, "foo")); + new IncomingMessage(1, destinationDeviceId2, 1, "foo"), + new IncomingMessage(1, destinationDeviceId3, 1, "foo")); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey(); - final Map preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); - final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + final Map preKeys = Map.of(Device.PRIMARY_ID, + KeysHelper.signedECPreKey(1, pniIdentityKeyPair), + destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), + destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair)); + final Map registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47, + destinationDeviceId3, 89); assertThrows(StaleDevicesException.class, () -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds)); @@ -404,11 +436,11 @@ public class ChangeNumberManagerTest { final List devices = new ArrayList<>(); - for (int i = 1; i <= 3; i++) { + for (byte i = 1; i <= 3; i++) { final Device device = mock(Device.class); - when(device.getId()).thenReturn((long) i); + when(device.getId()).thenReturn(i); when(device.isEnabled()).thenReturn(true); - when(device.getRegistrationId()).thenReturn(i); + when(device.getRegistrationId()).thenReturn((int) i); devices.add(device); when(account.getDevice(i)).thenReturn(Optional.of(device)); @@ -416,11 +448,13 @@ public class ChangeNumberManagerTest { when(account.getDevices()).thenReturn(devices); + final byte destinationDeviceId2 = 2; + final byte destinationDeviceId3 = 3; final List messages = List.of( - new IncomingMessage(1, 2, 2, "foo"), - new IncomingMessage(1, 3, 3, "foo")); + new IncomingMessage(1, destinationDeviceId2, 2, "foo"), + new IncomingMessage(1, destinationDeviceId3, 3, "foo")); - final Map registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89); + final Map registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89); assertThrows(IllegalArgumentException.class, () -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), null, null, messages, registrationIds)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index dbd0abcd9..54d8118ad 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -40,7 +40,7 @@ class KeysManagerTest { Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final long DEVICE_ID = 1L; + private static final byte DEVICE_ID = 1; private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); @@ -169,7 +169,8 @@ class KeysManagerTest { generateTestKEMSignedPreKey(6)) .join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, + final byte deviceId2 = DEVICE_ID + 1; + keysManager.store(ACCOUNT_UUID, deviceId2, List.of(generateTestPreKey(7)), List.of(generateTestKEMSignedPreKey(8)), generateTestECSignedPreKey(9), @@ -180,10 +181,10 @@ class KeysManagerTest { assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); keysManager.delete(ACCOUNT_UUID).join(); @@ -191,10 +192,10 @@ class KeysManagerTest { assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); + assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); } @Test @@ -206,7 +207,8 @@ class KeysManagerTest { generateTestKEMSignedPreKey(6)) .join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, + final byte deviceId2 = DEVICE_ID + 1; + keysManager.store(ACCOUNT_UUID, deviceId2, List.of(generateTestPreKey(7)), List.of(generateTestKEMSignedPreKey(8)), generateTestECSignedPreKey(9), @@ -217,10 +219,10 @@ class KeysManagerTest { assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join(); @@ -228,10 +230,10 @@ class KeysManagerTest { assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); } @Test @@ -240,21 +242,29 @@ class KeysManagerTest { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - keysManager.storePqLastResort( - ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join(); - assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); - assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId()); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent()); + final byte deviceId2 = 2; + final byte deviceId3 = 3; keysManager.storePqLastResort( ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join(); + Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), (byte) 2, + KeysHelper.signedKEMPreKey(2, identityKeyPair))).join(); + assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size()); + assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent()); + + keysManager.storePqLastResort( + ACCOUNT_UUID, + Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), deviceId3, + KeysHelper.signedKEMPreKey(4, identityKeyPair))).join(); assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates"); - assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); - assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone"); - assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(), + "storing new last-resort keys should overwrite old ones"); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId(), + "storing new last-resort keys should leave untouched ones alone"); + assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(), + "storing new last-resort keys should overwrite old ones"); } @Test @@ -262,11 +272,14 @@ class KeysManagerTest { final ECKeyPair identityKeyPair = Curve.generateKeyPair(); keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join(); + keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), null, null, null, + KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); + keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), null, + List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)) + .join(); + keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 3), null, null, null, null).join(); assertIterableEquals( - Set.of(DEVICE_ID + 1, DEVICE_ID + 2), + Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)), Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join())); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index af832d581..469886df2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -124,17 +124,19 @@ class MessagePersisterIntegrationTest { final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp); - messagesCache.insert(messageGuid, account.getUuid(), 1, message); + messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message); expectedMessages.add(message); } REDIS_CLUSTER_EXTENSION.getRedisCluster() .useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, - String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1))); + String.valueOf( + SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), Device.PRIMARY_ID)) - 1))); final AtomicBoolean messagesPersisted = new AtomicBoolean(false); - messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() { + messagesManager.addMessageAvailabilityListener(account.getUuid(), Device.PRIMARY_ID, + new MessageAvailabilityListener() { @Override public boolean handleNewMessagesAvailable() { return true; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 70259bfb8..df30773d6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -9,8 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; @@ -61,7 +61,7 @@ class MessagePersisterTest { private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; - private static final long DESTINATION_DEVICE_ID = 7; + private static final byte DESTINATION_DEVICE_ID = 7; private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); @@ -90,9 +90,9 @@ class MessagePersisterTest { messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY, 1); - doAnswer(invocation -> { + when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> { final UUID destinationUuid = invocation.getArgument(0); - final long destinationDeviceId = invocation.getArgument(1); + final byte destinationDeviceId = invocation.getArgument(1); final List messages = invocation.getArgument(2); messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); @@ -101,8 +101,8 @@ class MessagePersisterTest { messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get(); } - return null; - }).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any()); + return messages.size(); + }); } @AfterEach @@ -153,7 +153,7 @@ class MessagePersisterTest { messagePersister.persistNextQueues(now); - verify(messagesDynamoDb, never()).store(any(), any(), anyLong()); + verify(messagesDynamoDb, never()).store(any(), any(), anyByte()); } @Test @@ -166,7 +166,7 @@ class MessagePersisterTest { for (int i = 0; i < queueCount; i++) { final String queueName = generateRandomQueueNameForSlot(slot); final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName); - final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName); + final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queueName); final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10); final Account account = mock(Account.class); @@ -183,7 +183,7 @@ class MessagePersisterTest { final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong()); + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyByte()); assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @@ -219,7 +219,7 @@ class MessagePersisterTest { setNextSlotToPersist(SlotHash.getSlot(queueName)); // returning `0` indicates something not working correctly - when(messagesManager.persistMessages(any(UUID.class), anyLong(), anyList())).thenReturn(0); + when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenReturn(0); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> assertThrows(MessagePersistenceException.class, @@ -228,22 +228,23 @@ class MessagePersisterTest { @SuppressWarnings("SameParameterValue") private static String generateRandomQueueNameForSlot(final int slot) { - final UUID uuid = UUID.randomUUID(); - final String queueNameBase = "user_queue::{" + uuid + "::"; + while (true) { - for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) { - final String queueName = queueNameBase + deviceId + "}"; + final UUID uuid = UUID.randomUUID(); + final String queueNameBase = "user_queue::{" + uuid + "::"; - if (SlotHash.getSlot(queueName) == slot) { - return queueName; + for (byte deviceId = 1; deviceId < Device.MAXIMUM_DEVICE_ID; deviceId++) { + final String queueName = queueNameBase + deviceId + "}"; + + if (SlotHash.getSlot(queueName) == slot) { + return queueName; + } } } - - throw new IllegalStateException("Could not find a queue name for slot " + slot); } - private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount, + private void insertMessages(final UUID accountUuid, final byte deviceId, final int messageCount, final Instant firstMessageTimestamp) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index c9aff667d..9c7e7a559 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -85,7 +85,7 @@ class MessagesCacheTest { private static final UUID DESTINATION_UUID = UUID.randomUUID(); - private static final int DESTINATION_DEVICE_ID = 7; + private static final byte DESTINATION_DEVICE_ID = 7; @BeforeEach void setUp() throws Exception { @@ -311,7 +311,7 @@ class MessagesCacheTest { void testClearQueueForDevice(final boolean sealedSender) { final int messageCount = 100; - for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { + for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); @@ -323,7 +323,7 @@ class MessagesCacheTest { messagesCache.clear(DESTINATION_UUID, DESTINATION_DEVICE_ID).join(); assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(messageCount, get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); + assertEquals(messageCount, get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount).size()); } @ParameterizedTest @@ -331,7 +331,7 @@ class MessagesCacheTest { void testClearQueueForAccount(final boolean sealedSender) { final int messageCount = 100; - for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { + for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (int i = 0; i < messageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); @@ -343,7 +343,7 @@ class MessagesCacheTest { messagesCache.clear(DESTINATION_UUID).join(); assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); - assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount)); + assertEquals(Collections.emptyList(), get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount)); } @Test @@ -531,7 +531,7 @@ class MessagesCacheTest { }); } - private List get(final UUID destinationUuid, final long destinationDeviceId, + private List get(final UUID destinationUuid, final byte destinationDeviceId, final int messageCount) { return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId)) .take(messageCount, true) @@ -605,7 +605,7 @@ class MessagesCacheTest { .thenReturn(Flux.from(emptyFinalPagePublisher)) .thenReturn(Flux.empty()); - final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), 1L); + final Flux allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID); // Why initialValue = 3? // 1. messagesCache.getAllMessages() above produces the first call @@ -691,7 +691,7 @@ class MessagesCacheTest { when(asyncCommands.evalsha(any(), any(), any(), any())) .thenReturn((RedisFuture) removeSuccess); - final Publisher allMessages = messagesCache.get(UUID.randomUUID(), 1L); + final Publisher allMessages = messagesCache.get(UUID.randomUUID(), Device.PRIMARY_ID); StepVerifier.setDefaultTimeout(Duration.ofSeconds(5)); @@ -752,7 +752,7 @@ class MessagesCacheTest { .setDestinationUuid(UUID.randomUUID().toString()); if (!sealedSender) { - envelopeBuilder.setSourceDevice(random.nextInt(256)) + envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1) .setSourceUuid(UUID.randomUUID().toString()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java index 0f3c7dfca..29e731004 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java @@ -98,7 +98,7 @@ class MessagesDynamoDbTest { @Test void testSimpleFetchAfterInsert() { final UUID destinationUuid = UUID.randomUUID(); - final int destinationDeviceId = random.nextInt(255) + 1; + final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId); final List messagesStored = load(destinationUuid, destinationDeviceId, @@ -116,11 +116,12 @@ class MessagesDynamoDbTest { @ValueSource(ints = {10, 100, 100, 1_000, 3_000}) void testLoadManyAfterInsert(final int messageCount) { final UUID destinationUuid = UUID.randomUUID(); - final int destinationDeviceId = random.nextInt(255) + 1; + final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); final List messages = new ArrayList<>(messageCount); for (int i = 0; i < messageCount; i++) { - messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i)); + messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000, + "message " + i)); } messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); @@ -148,18 +149,20 @@ class MessagesDynamoDbTest { void testLimitedLoad() { final int messageCount = 200; final UUID destinationUuid = UUID.randomUUID(); - final int destinationDeviceId = random.nextInt(255) + 1; + final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1); final List messages = new ArrayList<>(messageCount); for (int i = 0; i < messageCount; i++) { - messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i)); + messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000, + "message " + i)); } messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); final int messageLoadLimit = 100; final int halfOfMessageLoadLimit = messageLoadLimit / 2; - final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit); + final Publisher fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, + messageLoadLimit); StepVerifier.setDefaultTimeout(Duration.ofSeconds(10)); @@ -170,7 +173,7 @@ class MessagesDynamoDbTest { .thenRequest(halfOfMessageLoadLimit) .expectNextCount(halfOfMessageLoadLimit) // the first 100 should be fetched and buffered, but further requests should fail - .then(() -> DYNAMO_DB_EXTENSION.stopServer()) + .then(DYNAMO_DB_EXTENSION::stopServer) .thenRequest(halfOfMessageLoadLimit) .expectNextCount(halfOfMessageLoadLimit) // we’ve consumed all the buffered messages, so a single request will fail @@ -183,22 +186,23 @@ class MessagesDynamoDbTest { void testDeleteForDestination() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); + final byte deviceId2 = 2; + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, deviceId2); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid).join(); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @@ -206,23 +210,26 @@ class MessagesDynamoDbTest { void testDeleteForDestinationDevice() { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); + final byte destinationDeviceId2 = 2; + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); - messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2).join(); + messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, destinationDeviceId2).join(); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty(); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .isEmpty(); + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); } @@ -230,15 +237,17 @@ class MessagesDynamoDbTest { void testDeleteMessageByDestinationAndGuid() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); + final byte destinationDeviceId2 = 2; + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); final Optional deletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid( @@ -247,11 +256,12 @@ class MessagesDynamoDbTest { assertThat(deletedMessage).isPresent(); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); final Optional alreadyDeletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid( @@ -266,29 +276,32 @@ class MessagesDynamoDbTest { void testDeleteSingleMessage() throws Exception { final UUID destinationUuid = UUID.randomUUID(); final UUID secondDestinationUuid = UUID.randomUUID(); - messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1); - messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2); + messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID); + messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID); + final byte destinationDeviceId2 = 2; + messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .hasSize(1).element(0).isEqualTo(MESSAGE2); - messagesDynamoDb.deleteMessage(secondDestinationUuid, 1, + messagesDynamoDb.deleteMessage(secondDestinationUuid, Device.PRIMARY_ID, UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS); - assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) .element(0).isEqualTo(MESSAGE1); - assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1) + assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + .hasSize(1) .element(0).isEqualTo(MESSAGE3); - assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() + assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull() .isEmpty(); } - private List load(final UUID destinationUuid, final long destinationDeviceId, + private List load(final UUID destinationUuid, final byte destinationDeviceId, final int count) { return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count)) .take(count, true) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java index 5c7f31af9..c899f0939 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java @@ -34,7 +34,7 @@ class MessagesManagerTest { final UUID destinationUuid = UUID.randomUUID(); - messagesManager.insert(destinationUuid, 1L, message); + messagesManager.insert(destinationUuid, Device.PRIMARY_ID, message); verify(reportMessageManager).store(eq(sourceAci.toString()), any(UUID.class)); @@ -42,7 +42,7 @@ class MessagesManagerTest { .setSourceUuid(destinationUuid.toString()) .build(); - messagesManager.insert(destinationUuid, 1L, syncMessage); + messagesManager.insert(destinationUuid, Device.PRIMARY_ID, syncMessage); verifyNoMoreInteractions(reportMessageManager); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java index dacbf13c3..4c30456c6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java @@ -25,7 +25,7 @@ class RefreshingAccountAndDeviceSupplierTest { final AccountsManager accountsManager = mock(AccountsManager.class); final UUID uuid = UUID.randomUUID(); - final long deviceId = 2L; + final byte deviceId = 2; final Account initialAccount = mock(Account.class); final Device initialDevice = mock(Device.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java index 57dcdfe95..c2d0d7fb4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java @@ -50,8 +50,8 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes @Test void storeIfAbsent() { final UUID identifier = UUID.randomUUID(); - final long deviceIdWithExistingKey = 1; - final long deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1; + final byte deviceIdWithExistingKey = 1; + final byte deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1; final ECSignedPreKey originalSignedPreKey = generateSignedPreKey(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 456b24450..3488b4abd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -24,11 +24,11 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { void storeFind() { final RepeatedUseSignedPreKeyStore keys = getKeyStore(); - assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join()); + assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), Device.PRIMARY_ID).join()); { final UUID identifier = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; final K signedPreKey = generateSignedPreKey(); assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); @@ -37,14 +37,15 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - 1L, generateSignedPreKey(), - 2L, generateSignedPreKey() + final byte deviceId2 = 2; + final Map signedPreKeys = Map.of( + Device.PRIMARY_ID, generateSignedPreKey(), + deviceId2, generateSignedPreKey() ); assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join()); - assertEquals(Optional.of(signedPreKeys.get(1L)), keys.find(identifier, 1).join()); - assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join()); + assertEquals(Optional.of(signedPreKeys.get(Device.PRIMARY_ID)), keys.find(identifier, Device.PRIMARY_ID).join()); + assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); } } @@ -54,32 +55,33 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); + final byte deviceId2 = 2; { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - 1L, generateSignedPreKey(), - 2L, generateSignedPreKey() + final Map signedPreKeys = Map.of( + Device.PRIMARY_ID, generateSignedPreKey(), + deviceId2, generateSignedPreKey() ); keys.store(identifier, signedPreKeys).join(); - keys.delete(identifier, 1).join(); + keys.delete(identifier, Device.PRIMARY_ID).join(); - assertEquals(Optional.empty(), keys.find(identifier, 1).join()); - assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join()); + assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); + assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); } { final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - 1L, generateSignedPreKey(), - 2L, generateSignedPreKey() + final Map signedPreKeys = Map.of( + Device.PRIMARY_ID, generateSignedPreKey(), + deviceId2, generateSignedPreKey() ); keys.store(identifier, signedPreKeys).join(); keys.delete(identifier).join(); - assertEquals(Optional.empty(), keys.find(identifier, 1).join()); - assertEquals(Optional.empty(), keys.find(identifier, 2).join()); + assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); + assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join()); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java index e284a6eb1..08693ce48 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -5,24 +5,15 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.ArrayList; -import java.util.Base64; import java.util.List; import java.util.Optional; import java.util.UUID; -import java.util.stream.Stream; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; import org.whispersystems.textsecuregcm.entities.PreKey; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.dynamodb.model.AttributeValue; abstract class SingleUsePreKeyStoreTest> { @@ -37,7 +28,7 @@ abstract class SingleUsePreKeyStoreTest> { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join()); @@ -58,7 +49,7 @@ abstract class SingleUsePreKeyStoreTest> { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); @@ -78,7 +69,7 @@ abstract class SingleUsePreKeyStoreTest> { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join()); @@ -90,12 +81,12 @@ abstract class SingleUsePreKeyStoreTest> { } preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); - preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join(); + preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join(); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); - assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); + assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join()); } @Test @@ -103,7 +94,7 @@ abstract class SingleUsePreKeyStoreTest> { final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); final UUID accountIdentifier = UUID.randomUUID(); - final long deviceId = 1; + final byte deviceId = 1; assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join()); @@ -115,11 +106,11 @@ abstract class SingleUsePreKeyStoreTest> { } preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); - preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join(); + preKeyStore.store(accountIdentifier, (byte) (deviceId + 1), preKeys).join(); assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); - assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); + assertEquals(0, preKeyStore.getCount(accountIdentifier, (byte) (deviceId + 1)).join()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index f874f7704..33b9160e3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.tests.util; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; @@ -61,9 +62,9 @@ public class AccountsHelper { return markStale ? copyAndMarkStale(account) : account; }); - when(mockAccountsManager.updateDevice(any(), anyLong(), any())).thenAnswer(answer -> { + when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> { final Account account = answer.getArgument(0, Account.class); - final Long deviceId = answer.getArgument(1, Long.class); + final byte deviceId = answer.getArgument(1, Byte.class); account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class)); return markStale ? copyAndMarkStale(account) : account; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index dee0f36bc..a47013751 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -121,12 +121,12 @@ public class AuthHelper { when(VALID_DEVICE_3_PRIMARY.isPrimary()).thenReturn(true); when(VALID_DEVICE_3_LINKED.isPrimary()).thenReturn(false); - when(VALID_DEVICE.getId()).thenReturn(1L); - when(VALID_DEVICE_TWO.getId()).thenReturn(1L); - when(DISABLED_DEVICE.getId()).thenReturn(1L); - when(UNDISCOVERABLE_DEVICE.getId()).thenReturn(1L); - when(VALID_DEVICE_3_PRIMARY.getId()).thenReturn(1L); - when(VALID_DEVICE_3_LINKED.getId()).thenReturn(2L); + when(VALID_DEVICE.getId()).thenReturn(Device.PRIMARY_ID); + when(VALID_DEVICE_TWO.getId()).thenReturn(Device.PRIMARY_ID); + when(DISABLED_DEVICE.getId()).thenReturn(Device.PRIMARY_ID); + when(UNDISCOVERABLE_DEVICE.getId()).thenReturn(Device.PRIMARY_ID); + when(VALID_DEVICE_3_PRIMARY.getId()).thenReturn(Device.PRIMARY_ID); + when(VALID_DEVICE_3_LINKED.getId()).thenReturn((byte) 2); when(VALID_DEVICE.isEnabled()).thenReturn(true); when(VALID_DEVICE_TWO.isEnabled()).thenReturn(true); @@ -135,17 +135,17 @@ public class AuthHelper { when(VALID_DEVICE_3_PRIMARY.isEnabled()).thenReturn(true); when(VALID_DEVICE_3_LINKED.isEnabled()).thenReturn(true); - when(VALID_ACCOUNT.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE)); + when(VALID_ACCOUNT.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE)); when(VALID_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE)); - when(VALID_ACCOUNT_TWO.getDevice(eq(1L))).thenReturn(Optional.of(VALID_DEVICE_TWO)); + when(VALID_ACCOUNT_TWO.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(VALID_DEVICE_TWO)); when(VALID_ACCOUNT_TWO.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO)); - when(DISABLED_ACCOUNT.getDevice(eq(1L))).thenReturn(Optional.of(DISABLED_DEVICE)); + when(DISABLED_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(DISABLED_DEVICE)); when(DISABLED_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(DISABLED_DEVICE)); - when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(1L))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); + when(UNDISCOVERABLE_ACCOUNT.getDevice(eq(Device.PRIMARY_ID))).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); when(UNDISCOVERABLE_ACCOUNT.getPrimaryDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); - when(VALID_ACCOUNT_3.getDevice(1L)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); + when(VALID_ACCOUNT_3.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); when(VALID_ACCOUNT_3.getPrimaryDevice()).thenReturn(Optional.of(VALID_DEVICE_3_PRIMARY)); - when(VALID_ACCOUNT_3.getDevice(2L)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED)); + when(VALID_ACCOUNT_3.getDevice((byte) 2)).thenReturn(Optional.of(VALID_DEVICE_3_LINKED)); when(VALID_ACCOUNT_TWO.hasEnabledLinkedDevice()).thenReturn(true); @@ -212,7 +212,7 @@ public class AuthHelper { DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter)); } - public static String getAuthHeader(UUID uuid, long deviceId, String password) { + public static String getAuthHeader(UUID uuid, byte deviceId, String password) { return HeaderUtils.basicAuthHeader(uuid.toString() + "." + deviceId, password); } @@ -260,9 +260,9 @@ public class AuthHelper { when(saltedTokenHash.verify(password)).thenReturn(true); when(device.getAuthTokenHash()).thenReturn(saltedTokenHash); when(device.isPrimary()).thenReturn(true); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(device.isEnabled()).thenReturn(true); - when(account.getDevice(1L)).thenReturn(Optional.of(device)); + when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(device)); when(account.getPrimaryDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn(number); when(account.getUuid()).thenReturn(uuid); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java index 7d9a453d9..2b877726a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/DevicesHelper.java @@ -14,15 +14,15 @@ public class DevicesHelper { private static final Random RANDOM = new Random(); - public static Device createDevice(final long deviceId) { + public static Device createDevice(final byte deviceId) { return createDevice(deviceId, 0); } - public static Device createDevice(final long deviceId, final long lastSeen) { + public static Device createDevice(final byte deviceId, final long lastSeen) { return createDevice(deviceId, lastSeen, 0); } - public static Device createDevice(final long deviceId, final long lastSeen, final int registrationId) { + public static Device createDevice(final byte deviceId, final long lastSeen, final int registrationId) { final Device device = new Device(); device.setId(deviceId); device.setLastSeen(lastSeen); @@ -34,7 +34,7 @@ public class DevicesHelper { return device; } - public static Device createDisabledDevice(final long deviceId, final int registrationId) { + public static Device createDisabledDevice(final byte deviceId, final int registrationId) { final Device device = new Device(); device.setId(deviceId); device.setUserAgent("OWT"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java index 0ff6e7856..30039d5ca 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java @@ -12,7 +12,7 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos; public class MessageHelper { - public static MessageProtos.Envelope createMessage(UUID senderUuid, final int senderDeviceId, UUID destinationUuid, + public static MessageProtos.Envelope createMessage(UUID senderUuid, final byte senderDeviceId, UUID destinationUuid, long timestamp, String content) { return MessageProtos.Envelope.newBuilder() .setServerGuid(UUID.randomUUID().toString()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java index 288b32afd..b78c6a71a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/DestinationDeviceValidatorTest.java @@ -35,7 +35,7 @@ import org.whispersystems.textsecuregcm.storage.Device; @ExtendWith(DropwizardExtensionsSupport.class) class DestinationDeviceValidatorTest { - static Account mockAccountWithDeviceAndRegId(final Map registrationIdsByDeviceId) { + static Account mockAccountWithDeviceAndRegId(final Map registrationIdsByDeviceId) { final Account account = mock(Account.class); registrationIdsByDeviceId.forEach((deviceId, registrationId) -> { @@ -48,31 +48,34 @@ class DestinationDeviceValidatorTest { } static Stream validateRegistrationIdsSource() { + final byte id1 = 1; + final byte id2 = 2; + final byte id3 = 3; return Stream.of( arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF)), - Map.of(1L, 0xFFFF, 2L, 0xDEAD, 3L, 0xBEEF), + mockAccountWithDeviceAndRegId(Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF)), + Map.of(id1, 0xFFFF, id2, 0xDEAD, id3, 0xBEEF), null), arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 42)), - Map.of(1L, 1492), - Set.of(1L)), + mockAccountWithDeviceAndRegId(Map.of(id1, 42)), + Map.of(id1, 1492), + Set.of(id1)), arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 42)), - Map.of(1L, 42), + mockAccountWithDeviceAndRegId(Map.of(id1, 42)), + Map.of(id1, 42), null), arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 42)), - Map.of(1L, 0), + mockAccountWithDeviceAndRegId(Map.of(id1, 42)), + Map.of(id1, 0), null), arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 255)), - Map.of(1L, 0, 2L, 42), - Set.of(2L)), + mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 255)), + Map.of(id1, 0, id2, 42), + Set.of(id2)), arguments( - mockAccountWithDeviceAndRegId(Map.of(1L, 42, 2L, 256)), - Map.of(1L, 41, 2L, 257), - Set.of(1L, 2L)) + mockAccountWithDeviceAndRegId(Map.of(id1, 42, id2, 256)), + Map.of(id1, 41, id2, 257), + Set.of(id1, id2)) ); } @@ -80,8 +83,8 @@ class DestinationDeviceValidatorTest { @MethodSource("validateRegistrationIdsSource") void testValidateRegistrationIds( Account account, - Map registrationIdsByDeviceId, - Set expectedStaleDeviceIds) throws Exception { + Map registrationIdsByDeviceId, + Set expectedStaleDeviceIds) throws Exception { if (expectedStaleDeviceIds != null) { Assertions.assertThat(assertThrows(StaleDevicesException.class, () -> DestinationDeviceValidator.validateRegistrationIds( @@ -98,7 +101,7 @@ class DestinationDeviceValidatorTest { } } - static Account mockAccountWithDeviceAndEnabled(final Map enabledStateByDeviceId) { + static Account mockAccountWithDeviceAndEnabled(final Map enabledStateByDeviceId) { final Account account = mock(Account.class); final List devices = new ArrayList<>(); @@ -117,51 +120,54 @@ class DestinationDeviceValidatorTest { } static Stream validateCompleteDeviceListSource() { + final byte id1 = 1; + final byte id2 = 2; + final byte id3 = 3; return Stream.of( arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(1L, 3L), + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id1, id3), null, null, Collections.emptySet()), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(1L, 2L, 3L), + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id1, id2, id3), null, - Set.of(2L), + Set.of(id2), Collections.emptySet()), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(1L), - Set.of(3L), + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id1), + Set.of(id3), null, Collections.emptySet()), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(1L, 2L), - Set.of(3L), - Set.of(2L), + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id1, id2), + Set.of(id3), + Set.of(id2), Collections.emptySet()), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(1L), - Set.of(3L), - Set.of(1L), - Set.of(1L) + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id1), + Set.of(id3), + Set.of(id1), + Set.of(id1) ), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(2L), - Set.of(3L), - Set.of(2L), - Set.of(1L) + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id2), + Set.of(id3), + Set.of(id2), + Set.of(id1) ), arguments( - mockAccountWithDeviceAndEnabled(Map.of(1L, true, 2L, false, 3L, true)), - Set.of(3L), + mockAccountWithDeviceAndEnabled(Map.of(id1, true, id2, false, id3, true)), + Set.of(id3), null, null, - Set.of(1L) + Set.of(id1) ) ); } @@ -170,10 +176,10 @@ class DestinationDeviceValidatorTest { @MethodSource("validateCompleteDeviceListSource") void testValidateCompleteDeviceList( Account account, - Set deviceIds, - Collection expectedMissingDeviceIds, - Collection expectedExtraDeviceIds, - Set excludedDeviceIds) throws Exception { + Set deviceIds, + Collection expectedMissingDeviceIds, + Collection expectedExtraDeviceIds, + Set excludedDeviceIds) throws Exception { if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 5d24ffae1..d13ddc5ee 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -103,7 +103,7 @@ class WebSocketConnectionIntegrationTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); } @AfterEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 5076b27af..86ea865e6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.nullable; @@ -162,7 +163,7 @@ class WebSocketConnectionTest { createMessage(senderOneUuid, accountUuid, 2222, "second"), createMessage(senderTwoUuid, accountUuid, 3333, "third")); - final long deviceId = 2L; + final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); @@ -178,7 +179,7 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); String userAgent = HttpHeaders.USER_AGENT; @@ -232,10 +233,10 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.empty()) .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"))) .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second"))) @@ -310,7 +311,7 @@ class WebSocketConnectionTest { final List pendingMessages = List.of(firstMessage, secondMessage); - final long deviceId = 2L; + final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); @@ -326,7 +327,7 @@ class WebSocketConnectionTest { when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty()); - when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); String userAgent = HttpHeaders.USER_AGENT; @@ -374,14 +375,14 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false); when( - messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false)) + messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false)) .thenAnswer(invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); @@ -428,7 +429,7 @@ class WebSocketConnectionTest { } }); - verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyLong(), eq(false)); + verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyByte(), eq(false)); } @Test @@ -440,7 +441,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); final UUID accountUuid = UUID.randomUUID(); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); final List firstPageMessages = @@ -450,10 +451,10 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), eq(false))) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), eq(false))) .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream()))); - when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any())) + when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -492,18 +493,18 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); final UUID accountUuid = UUID.randomUUID(); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); final UUID senderUuid = UUID.randomUUID(); final List messages = List.of( createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), Device.PRIMARY_ID, false)) .thenReturn(Flux.fromIterable(messages)) .thenReturn(Flux.empty()); - when(messagesManager.delete(eq(accountUuid), eq(1L), any(UUID.class), any())) + when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(UUID.class), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -555,10 +556,10 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -583,7 +584,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); final List firstPageMessages = @@ -593,12 +594,12 @@ class WebSocketConnectionTest { final List secondPageMessages = List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third")); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.fromIterable(firstPageMessages)) .thenReturn(Flux.fromIterable(secondPageMessages)) .thenReturn(Flux.empty()); - when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any())) + when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -640,10 +641,10 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -672,10 +673,10 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); + when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); - when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean())) + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), anyBoolean())) .thenReturn(Flux.empty()); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -695,7 +696,7 @@ class WebSocketConnectionTest { void testRetrieveMessageException() { UUID accountUuid = UUID.randomUUID(); - when(device.getId()).thenReturn(2L); + when(device.getId()).thenReturn((byte) 2); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -725,7 +726,7 @@ class WebSocketConnectionTest { void testRetrieveMessageExceptionClientDisconnected() { UUID accountUuid = UUID.randomUUID(); - when(device.getId()).thenReturn(2L); + when(device.getId()).thenReturn((byte) 2); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -748,7 +749,7 @@ class WebSocketConnectionTest { void testReactivePublisherLimitRate() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 2L; + final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); @@ -767,7 +768,7 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, @@ -798,7 +799,7 @@ class WebSocketConnectionTest { void testReactivePublisherDisposedWhenConnectionStopped() { final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 2L; + final byte deviceId = 2; when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); @@ -824,7 +825,7 @@ class WebSocketConnectionTest { final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse)); - when(messagesManager.delete(any(), anyLong(), any(), any())).thenReturn( + when(messagesManager.delete(any(), anyByte(), any(), any())).thenReturn( CompletableFuture.completedFuture(Optional.empty())); WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,