Convert Device.id from `long` to `byte`

This commit is contained in:
Chris Eager 2023-10-24 18:58:13 -05:00 committed by Chris Eager
parent 7299067829
commit 6a428b4da9
112 changed files with 1292 additions and 1094 deletions

View File

@ -236,7 +236,7 @@ public final class Operations {
return authorized(user, Device.PRIMARY_ID);
}
public RequestBuilder authorized(final TestUser user, final long deviceId) {
public RequestBuilder authorized(final TestUser user, final byte deviceId) {
final String username = "%s.%d".formatted(user.aciUuid().toString(), deviceId);
return authorized(username, user.accountPassword());
}

View File

@ -16,13 +16,13 @@ import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
public class TestDevice {
private final long deviceId;
private final byte deviceId;
private final Map<Integer, Pair<IdentityKeyPair, SignedPreKeyRecord>> signedPreKeys = new ConcurrentHashMap<>();
public static TestDevice create(
final long deviceId,
final byte deviceId,
final IdentityKeyPair aciIdentityKeyPair,
final IdentityKeyPair pniIdentityKeyPair) {
final TestDevice device = new TestDevice(deviceId);
@ -31,11 +31,11 @@ public class TestDevice {
return device;
}
public TestDevice(final long deviceId) {
public TestDevice(final byte deviceId) {
this.deviceId = deviceId;
}
public long deviceId() {
public byte deviceId() {
return deviceId;
}

View File

@ -30,7 +30,7 @@ public class TestUser {
private final IdentityKeyPair aciIdentityKey;
private final Map<Long, TestDevice> devices = new ConcurrentHashMap<>();
private final Map<Byte, TestDevice> devices = new ConcurrentHashMap<>();
private final byte[] unidentifiedAccessKey;
@ -147,7 +147,7 @@ public class TestUser {
this.registrationPassword = registrationPassword;
}
public PreKeySetPublicView preKeys(final long deviceId, final boolean pni) {
public PreKeySetPublicView preKeys(final byte deviceId, final boolean pni) {
final IdentityKeyPair identity = pni
? pniIdentityKey
: aciIdentityKey;

View File

@ -47,7 +47,7 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
}
@VisibleForTesting
static Map<Long, Boolean> buildDevicesEnabledMap(final Account account) {
static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
@ -68,17 +68,17 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
}
@Override
public List<Pair<UUID, Long>> handleRequestFinished(final RequestEvent requestEvent) {
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
// change or if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled =
(Map<Long, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> {
final Set<Long> deviceIdsToDisplace;
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());

View File

@ -52,9 +52,9 @@ public class BaseAccountAuthenticator {
this.clock = clock;
}
static Pair<String, Long> getIdentifierAndDeviceId(final String basicUsername) {
static Pair<String, Byte> getIdentifierAndDeviceId(final String basicUsername) {
final String identifier;
final long deviceId;
final byte deviceId;
final int deviceIdSeparatorIndex = basicUsername.indexOf(DEVICE_ID_SEPARATOR);
@ -63,7 +63,7 @@ public class BaseAccountAuthenticator {
deviceId = Device.PRIMARY_ID;
} else {
identifier = basicUsername.substring(0, deviceIdSeparatorIndex);
deviceId = Long.parseLong(basicUsername.substring(deviceIdSeparatorIndex + 1));
deviceId = Byte.parseByte(basicUsername.substring(deviceIdSeparatorIndex + 1));
}
return new Pair<>(identifier, deviceId);
@ -75,9 +75,9 @@ public class BaseAccountAuthenticator {
try {
final UUID accountUuid;
final long deviceId;
final byte deviceId;
{
final Pair<String, Long> identifierAndDeviceId = getIdentifierAndDeviceId(basicCredentials.getUsername());
final Pair<String, Byte> identifierAndDeviceId = getIdentifierAndDeviceId(basicCredentials.getUsername());
accountUuid = UUID.fromString(identifierAndDeviceId.first());
deviceId = identifierAndDeviceId.second();

View File

@ -11,10 +11,10 @@ import org.whispersystems.textsecuregcm.util.Pair;
public class BasicAuthorizationHeader {
private final String username;
private final long deviceId;
private final byte deviceId;
private final String password;
private BasicAuthorizationHeader(final String username, final long deviceId, final String password) {
private BasicAuthorizationHeader(final String username, final byte deviceId, final String password) {
this.username = username;
this.deviceId = deviceId;
this.password = password;
@ -59,9 +59,9 @@ public class BasicAuthorizationHeader {
final String usernameComponent = credentials.substring(0, credentialSeparatorIndex);
final String username;
final long deviceId;
final byte deviceId;
{
final Pair<String, Long> identifierAndDeviceId =
final Pair<String, Byte> identifierAndDeviceId =
BaseAccountAuthenticator.getIdentifierAndDeviceId(usernameComponent);
username = identifierAndDeviceId.first();

View File

@ -29,7 +29,7 @@ public class OptionalAccess {
verify(requestAccount, accessKey, targetAccount);
if (!deviceSelector.equals("*")) {
long deviceId = Long.parseLong(deviceSelector);
byte deviceId = Byte.parseByte(deviceSelector);
Optional<Device> targetDevice = targetAccount.get().getDevice(deviceId);

View File

@ -26,7 +26,7 @@ public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRef
}
@Override
public List<Pair<UUID, Long>> handleRequestFinished(final RequestEvent requestEvent) {
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber != null) {

View File

@ -157,7 +157,7 @@ public class RegistrationLockVerificationManager {
registrationRecoveryPasswordsManager.removeForNumber(updatedAccount.getNumber());
}
final List<Long> deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList();
final List<Byte> deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList();
clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds);
try {

View File

@ -30,5 +30,5 @@ public interface WebsocketRefreshRequirementProvider {
* @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a
* result of the observed request
*/
List<Pair<UUID, Long>> handleRequestFinished(RequestEvent requestEvent);
List<Pair<UUID, Byte>> handleRequestFinished(RequestEvent requestEvent);
}

View File

@ -7,5 +7,5 @@ package org.whispersystems.textsecuregcm.auth.grpc;
import java.util.UUID;
public record AuthenticatedDevice(UUID accountIdentifier, long deviceId) {
public record AuthenticatedDevice(UUID accountIdentifier, byte deviceId) {
}

View File

@ -17,7 +17,7 @@ import org.whispersystems.textsecuregcm.storage.Device;
public class AuthenticationUtil {
static final Context.Key<UUID> CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY = Context.key("authenticated-aci");
static final Context.Key<Long> CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
static final Context.Key<Byte> CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY = Context.key("authenticated-device-id");
/**
* Returns the account/device authenticated in the current gRPC context or throws an "unauthenticated" exception if
@ -30,7 +30,7 @@ public class AuthenticationUtil {
*/
public static AuthenticatedDevice requireAuthenticatedDevice() {
@Nullable final UUID accountIdentifier = CONTEXT_AUTHENTICATED_ACCOUNT_IDENTIFIER_KEY.get();
@Nullable final Long deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
@Nullable final Byte deviceId = CONTEXT_AUTHENTICATED_DEVICE_IDENTIFIER_KEY.get();
if (accountIdentifier != null && deviceId != null) {
return new AuthenticatedDevice(accountIdentifier, deviceId);

View File

@ -217,7 +217,7 @@ public class AccountController {
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent,
@NotNull @Valid AccountAttributes attributes) {
final Account account = disabledPermittedAuth.getAccount();
final long deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId();
final byte deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId();
final Account updatedAccount = accounts.update(account, a -> {
a.getDevice(deviceId).ifPresent(d -> {

View File

@ -135,7 +135,7 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{device_id}")
@ChangesDeviceEnabledState
public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) {
public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") byte deviceId) {
Account account = auth.getAccount();
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
@ -256,7 +256,7 @@ public class DeviceController {
@Path("/capabilities")
public void setCapabilities(@Auth AuthenticatedAccount auth, @NotNull @Valid DeviceCapabilities capabilities) {
assert (auth.getAuthenticatedDevice() != null);
final long deviceId = auth.getAuthenticatedDevice().getId();
final byte deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities));
}

View File

@ -332,7 +332,7 @@ public class KeysController {
return account.getDevices().stream().filter(Device::isEnabled).toList();
}
try {
long id = Long.parseLong(deviceId);
byte id = Byte.parseByte(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());

View File

@ -283,7 +283,7 @@ public class MessageController {
checkStoryRateLimit(destination.get(), userAgent);
}
final Set<Long> excludedDeviceIds;
final Set<Byte> excludedDeviceIds;
if (isSyncMessage) {
excludedDeviceIds = Set.of(source.get().getAuthenticatedDevice().getId());
@ -346,7 +346,7 @@ public class MessageController {
/**
* Build mapping of accounts to devices/registration IDs.
*/
private Map<Account, Set<Pair<Long, Integer>>> buildDeviceIdAndRegistrationIdMap(
private Map<Account, Set<Pair<Byte, Integer>>> buildDeviceIdAndRegistrationIdMap(
MultiRecipientMessage multiRecipientMessage,
Map<ServiceIdentifier, Account> accountsByServiceIdentifier) {
@ -403,7 +403,7 @@ public class MessageController {
checkAccessKeys(accessKeys, accountsByServiceIdentifier.values());
}
final Map<Account, Set<Pair<Long, Integer>>> accountToDeviceIdAndRegistrationIdMap =
final Map<Account, Set<Pair<Byte, Integer>>> accountToDeviceIdAndRegistrationIdMap =
buildDeviceIdAndRegistrationIdMap(multiRecipientMessage, accountsByServiceIdentifier);
// We might filter out all the recipients of a story (if none have enabled stories).
@ -420,7 +420,7 @@ public class MessageController {
checkStoryRateLimit(account, userAgent);
}
Set<Long> deviceIds = accountToDeviceIdAndRegistrationIdMap
Set<Byte> deviceIds = accountToDeviceIdAndRegistrationIdMap
.getOrDefault(account, Collections.emptySet())
.stream()
.map(Pair::first)
@ -678,7 +678,7 @@ public class MessageController {
try {
Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null);
Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
Byte sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
envelope = incomingMessage.toEnvelope(
destinationIdentifier,
sourceAccount,

View File

@ -9,19 +9,19 @@ import java.util.List;
public class MismatchedDevicesException extends Exception {
private final List<Long> missingDevices;
private final List<Long> extraDevices;
private final List<Byte> missingDevices;
private final List<Byte> extraDevices;
public MismatchedDevicesException(List<Long> missingDevices, List<Long> extraDevices) {
public MismatchedDevicesException(List<Byte> missingDevices, List<Byte> extraDevices) {
this.missingDevices = missingDevices;
this.extraDevices = extraDevices;
}
public List<Long> getMissingDevices() {
public List<Byte> getMissingDevices() {
return missingDevices;
}
public List<Long> getExtraDevices() {
public List<Byte> getExtraDevices() {
return extraDevices;
}
}

View File

@ -47,7 +47,7 @@ public class ProvisioningController {
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, (byte) 0),
Base64.getMimeDecoder().decode(message.body()))) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}

View File

@ -9,13 +9,14 @@ import java.util.List;
public class StaleDevicesException extends Exception {
private final List<Long> staleDevices;
public StaleDevicesException(List<Long> staleDevices) {
private final List<Byte> staleDevices;
public StaleDevicesException(List<Byte> staleDevices) {
this.staleDevices = staleDevices;
}
public List<Long> getStaleDevices() {
public List<Byte> getStaleDevices() {
return staleDevices;
}
}

View File

@ -98,7 +98,7 @@ public record AccountDataReportResponse(UUID reportId,
}
public record DeviceDataReport(long id,
public record DeviceDataReport(byte id,
@JsonFormat(pattern = DATE_FORMAT, timezone = UTC)
Instant lastSeen,
@JsonFormat(pattern = DATE_FORMAT, timezone = UTC)

View File

@ -54,7 +54,7 @@ public record ChangeNumberRequest(
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@NotNull @Valid Map<Long, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@NotNull @Valid Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
@ -62,10 +62,10 @@ public record ChangeNumberRequest(
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Schema(description="the new phone-number-identity registration ID for each enabled device on the account, including this one")
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@NotNull Map<Byte, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {

View File

@ -18,12 +18,12 @@ public class DeviceResponse {
private UUID pni;
@JsonProperty
private long deviceId;
private byte deviceId;
@VisibleForTesting
public DeviceResponse() {}
public DeviceResponse(UUID uuid, UUID pni, long deviceId) {
public DeviceResponse(UUID uuid, UUID pni, byte deviceId) {
this.uuid = uuid;
this.pni = pni;
this.deviceId = deviceId;
@ -37,7 +37,7 @@ public class DeviceResponse {
return pni;
}
public long getDeviceId() {
public byte getDeviceId() {
return deviceId;
}
}

View File

@ -12,11 +12,11 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
public record IncomingMessage(int type, long destinationDeviceId, int destinationRegistrationId, String content) {
public record IncomingMessage(int type, byte destinationDeviceId, int destinationRegistrationId, String content) {
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
@Nullable Account sourceAccount,
@Nullable Long sourceDeviceId,
@Nullable Byte sourceDeviceId,
final long timestamp,
final boolean story,
final boolean urgent,

View File

@ -12,9 +12,9 @@ import java.util.List;
public record MismatchedDevices(@JsonProperty
@Schema(description = "Devices present on the account but absent in the request")
List<Long> missingDevices,
List<Byte> missingDevices,
@JsonProperty
@Schema(description = "Devices absent on the request but present in the account")
List<Long> extraDevices) {
List<Byte> extraDevices) {
}

View File

@ -40,7 +40,7 @@ public record MultiRecipientMessage(
@JsonSerialize(using = ServiceIdentifierAdapter.ServiceIdentifierSerializer.class)
@JsonDeserialize(using = ServiceIdentifierAdapter.ServiceIdentifierDeserializer.class)
ServiceIdentifier uuid,
@Min(1) long deviceId,
@Min(1) byte deviceId,
@Min(0) @Max(65535) int registrationId,
@Size(min = 48, max = 48) @NotNull byte[] perRecipientKeyMaterial) {

View File

@ -22,7 +22,7 @@ public record PhoneNumberIdentityKeyDistributionRequest(
@JsonDeserialize(using = IdentityKeyAdapter.Deserializer.class)
@Schema(description="the new identity key for this account's phone-number identity")
IdentityKey pniIdentityKey,
@NotNull
@Valid
@ArraySchema(
@ -32,26 +32,26 @@ public record PhoneNumberIdentityKeyDistributionRequest(
Exactly one message must be supplied for each enabled device other than the sending (primary) device.
"""))
List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull
@Valid
@Schema(description="""
A new signed elliptic-curve prekey for each enabled device on the account, including this one.
Each must be accompanied by a valid signature from the new identity key in this request.""")
Map<Long, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
Map<Byte, @NotNull @Valid ECSignedPreKey> devicePniSignedPrekeys,
@Schema(description="""
A new signed post-quantum last-resort prekey for each enabled device on the account, including this one.
May be absent, in which case the last resort PQ prekeys for each device will be deleted if any had been stored.
If present, must contain one prekey per enabled device including this one.
Prekeys for devices that did not previously have any post-quantum prekeys stored will be silently dropped.
Each must be accompanied by a valid signature from the new identity key in this request.""")
@Valid Map<Long, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@Valid Map<Byte, @NotNull @Valid KEMSignedPreKey> devicePniPqLastResortPrekeys,
@NotNull
@Valid
@Schema(description="The new registration ID to use for the phone-number identity of each device, including this one.")
Map<Long, Integer> pniRegistrationIds) {
Map<Byte, Integer> pniRegistrationIds) {
@AssertTrue
public boolean isSignatureValidOnEachSignedPreKey() {

View File

@ -40,7 +40,7 @@ public class PreKeyResponse {
@VisibleForTesting
@JsonIgnore
public PreKeyResponseItem getDevice(int deviceId) {
public PreKeyResponseItem getDevice(byte deviceId) {
for (PreKeyResponseItem device : devices) {
if (device.getDeviceId() == deviceId) return device;
}

View File

@ -12,7 +12,7 @@ public class PreKeyResponseItem {
@JsonProperty
@Schema(description="the device ID of the device to which this item pertains")
private long deviceId;
private byte deviceId;
@JsonProperty
@Schema(description="the registration ID for the device")
@ -33,7 +33,8 @@ public class PreKeyResponseItem {
public PreKeyResponseItem() {}
public PreKeyResponseItem(long deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey, KEMSignedPreKey pqPreKey) {
public PreKeyResponseItem(byte deviceId, int registrationId, ECSignedPreKey signedPreKey, ECPreKey preKey,
KEMSignedPreKey pqPreKey) {
this.deviceId = deviceId;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
@ -62,7 +63,7 @@ public class PreKeyResponseItem {
}
@VisibleForTesting
public long getDeviceId() {
public byte getDeviceId() {
return deviceId;
}
}

View File

@ -12,5 +12,5 @@ import java.util.List;
public record StaleDevices(@JsonProperty
@Schema(description = "Devices that are no longer active")
List<Long> staleDevices) {
List<Byte> staleDevices) {
}

View File

@ -1,51 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
public class UnregisteredEvent {
@JsonProperty
@NotEmpty
private String registrationId;
@JsonProperty
private String canonicalId;
@JsonProperty
@NotEmpty
private String number;
@JsonProperty
@Min(1)
private int deviceId;
@JsonProperty
private long timestamp;
public String getRegistrationId() {
return registrationId;
}
public String getCanonicalId() {
return canonicalId;
}
public String getNumber() {
return number;
}
public int getDeviceId() {
return deviceId;
}
public long getTimestamp() {
return timestamp;
}
}

View File

@ -1,22 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.LinkedList;
import java.util.List;
public class UnregisteredEventList {
@JsonProperty
private List<UnregisteredEvent> devices;
public List<UnregisteredEvent> getDevices() {
if (devices == null) return new LinkedList<>();
else return devices;
}
}

View File

@ -0,0 +1,18 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.grpc;
import io.grpc.Status;
public class DeviceIdUtil {
static byte validate(int deviceId) {
if (deviceId > Byte.MAX_VALUE) {
throw Status.INVALID_ARGUMENT.withDescription("Device ID is out of range").asRuntimeException();
}
return (byte) deviceId;
}
}

View File

@ -78,18 +78,19 @@ public class DevicesGrpcService extends ReactorDevicesGrpc.DevicesImplBase {
if (request.getId() == Device.PRIMARY_ID) {
throw Status.INVALID_ARGUMENT.withDescription("Cannot remove primary device").asRuntimeException();
}
final byte deviceId = DeviceIdUtil.validate(request.getId());
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedPrimaryDevice();
return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier()))
.map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
.flatMap(account -> Flux.merge(
Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), request.getId())),
Mono.fromFuture(() -> keysManager.delete(account.getUuid(), request.getId())))
.then(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> a.removeDevice(request.getId()))))
Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId)),
Mono.fromFuture(() -> keysManager.delete(account.getUuid(), deviceId)))
.then(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> a.removeDevice(deviceId))))
// Some messages may have arrived while we were performing the other updates; make a best effort to clear
// those out, too
.then(Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), request.getId()))))
.then(Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId))))
.thenReturn(RemoveDeviceResponse.newBuilder().build());
}

View File

@ -39,12 +39,14 @@ public class KeysAnonymousGrpcService extends ReactorKeysAnonymousGrpc.KeysAnony
final ServiceIdentifier serviceIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getRequest().getTargetIdentifier());
final byte deviceId = DeviceIdUtil.validate(request.getRequest().getDeviceId());
return Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(serviceIdentifier))
.flatMap(Mono::justOrEmpty)
.switchIfEmpty(Mono.error(Status.UNAUTHENTICATED.asException()))
.flatMap(targetAccount ->
UnidentifiedAccessUtil.checkUnidentifiedAccess(targetAccount, request.getUnidentifiedAccessKey().toByteArray())
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), request.getRequest().getDeviceId(), keysManager)
? KeysGrpcHelper.getPreKeys(targetAccount, serviceIdentifier.identityType(), deviceId, keysManager)
: Mono.error(Status.UNAUTHENTICATED.asException()));
}

View File

@ -27,11 +27,11 @@ import reactor.util.function.Tuples;
class KeysGrpcHelper {
@VisibleForTesting
static final long ALL_DEVICES = 0;
static final byte ALL_DEVICES = 0;
static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount,
final IdentityType identityType,
final long targetDeviceId,
final byte targetDeviceId,
final KeysManager keysManager) {
final Flux<Device> devices = targetDeviceId == ALL_DEVICES
@ -73,7 +73,8 @@ class KeysGrpcHelper {
return builder;
})
.map(builder -> Tuples.of(device.getId(), builder.build()));
// Cast device IDs to `int` to match data types in the response objects protobuf definition
.map(builder -> Tuples.of((int) device.getId(), builder.build()));
})
.collectMap(Tuple2::getT1, Tuple2::getT2)
.map(preKeyBundles -> GetPreKeysResponse.newBuilder()

View File

@ -124,17 +124,19 @@ public class KeysGrpcService extends ReactorKeysGrpc.KeysImplBase {
final ServiceIdentifier targetIdentifier =
ServiceIdentifierUtil.fromGrpcServiceIdentifier(request.getTargetIdentifier());
final byte deviceId = DeviceIdUtil.validate(request.getDeviceId());
final String rateLimitKey = authenticatedDevice.accountIdentifier() + "." +
authenticatedDevice.deviceId() + "__" +
targetIdentifier.uuid() + "." +
request.getDeviceId();
deviceId;
return rateLimiters.getPreKeysLimiter().validateReactive(rateLimitKey)
.then(Mono.fromFuture(() -> accountsManager.getByServiceIdentifierAsync(targetIdentifier))
.flatMap(Mono::justOrEmpty))
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(targetAccount ->
KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), request.getDeviceId(), keysManager));
KeysGrpcHelper.getPreKeys(targetAccount, targetIdentifier.identityType(), deviceId, keysManager));
}
@Override

View File

@ -83,7 +83,14 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
MultiRecipientMessage.Recipient[] recipients = new MultiRecipientMessage.Recipient[Math.toIntExact(count)];
for (int i = 0; i < Math.toIntExact(count); i++) {
ServiceIdentifier identifier = readIdentifier(entityStream, version);
long deviceId = readVarint(entityStream);
final byte deviceId;
{
long deviceIdLong = readVarint(entityStream);
if (deviceIdLong > Byte.MAX_VALUE) {
throw new BadRequestException("Invalid device ID");
}
deviceId = (byte) deviceIdLong;
}
int registrationId = readU16(entityStream);
byte[] perRecipientKeyMaterial = entityStream.readNBytes(48);
if (perRecipientKeyMaterial.length != 48) {

View File

@ -300,7 +300,7 @@ public class ApnPushNotificationScheduler implements Managed {
}
@VisibleForTesting
static Optional<Pair<String, Long>> getSeparated(String encoded) {
static Optional<Pair<String, Byte>> getSeparated(String encoded) {
try {
if (encoded == null) return Optional.empty();
@ -311,7 +311,7 @@ public class ApnPushNotificationScheduler implements Managed {
return Optional.empty();
}
return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1])));
return Optional.of(new Pair<>(parts[0], Byte.parseByte(parts[1])));
} catch (NumberFormatException e) {
logger.warn("Badly formatted: " + encoded, e);
return Optional.empty();
@ -338,7 +338,7 @@ public class ApnPushNotificationScheduler implements Managed {
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(UUID.fromString(parts[0]));
return maybeAccount.flatMap(account -> account.getDevice(Long.parseLong(parts[1])))
return maybeAccount.flatMap(account -> account.getDevice(Byte.parseByte(parts[1])))
.map(device -> new Pair<>(maybeAccount.get(), device));
} catch (final NumberFormatException e) {

View File

@ -21,6 +21,7 @@ import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import io.lettuce.core.cluster.pubsub.RedisClusterPubSubAdapter;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
@ -34,7 +35,6 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
@ -162,7 +162,8 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
connection -> connection.sync().upstream().commands().unsubscribe(getManagerPresenceChannel(managerId)));
}
public void setPresent(final UUID accountUuid, final long deviceId, final DisplacedPresenceListener displacementListener) {
public void setPresent(final UUID accountUuid, final byte deviceId,
final DisplacedPresenceListener displacementListener) {
try (final Timer.Context ignored = setPresenceTimer.time()) {
final String presenceKey = getPresenceKey(accountUuid, deviceId);
@ -182,12 +183,12 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
}
}
public void renewPresence(final UUID accountUuid, final long deviceId) {
public void renewPresence(final UUID accountUuid, final byte deviceId) {
renewPresenceScript.execute(List.of(getPresenceKey(accountUuid, deviceId)),
List.of(managerId, String.valueOf(PRESENCE_EXPIRATION_SECONDS)));
}
public void disconnectAllPresences(final UUID accountUuid, final List<Long> deviceIds) {
public void disconnectAllPresences(final UUID accountUuid, final List<Byte> deviceIds) {
List<String> presenceKeys = new ArrayList<>();
deviceIds.forEach(deviceId -> {
@ -208,7 +209,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
disconnectAllPresences(accountUuid, Device.ALL_POSSIBLE_DEVICE_IDS);
}
public void disconnectPresence(final UUID accountUuid, final long deviceId) {
public void disconnectPresence(final UUID accountUuid, final byte deviceId) {
disconnectAllPresences(accountUuid, List.of(deviceId));
}
@ -222,18 +223,18 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
clearPresence(presenceKey);
}
public boolean isPresent(final UUID accountUuid, final long deviceId) {
public boolean isPresent(final UUID accountUuid, final byte deviceId) {
try (final Timer.Context ignored = checkPresenceTimer.time()) {
return presenceCluster.withCluster(connection ->
connection.sync().exists(getPresenceKey(accountUuid, deviceId))) == 1;
}
}
public boolean isLocallyPresent(final UUID accountUuid, final long deviceId) {
public boolean isLocallyPresent(final UUID accountUuid, final byte deviceId) {
return displacementListenersByPresenceKey.containsKey(getPresenceKey(accountUuid, deviceId));
}
public boolean clearPresence(final UUID accountUuid, final long deviceId, final DisplacedPresenceListener listener) {
public boolean clearPresence(final UUID accountUuid, final byte deviceId, final DisplacedPresenceListener listener) {
final String presenceKey = getPresenceKey(accountUuid, deviceId);
if (displacementListenersByPresenceKey.remove(presenceKey, listener)) {
return clearPresence(presenceKey);
@ -337,7 +338,7 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
}
@VisibleForTesting
static String getPresenceKey(final UUID accountUuid, final long deviceId) {
static String getPresenceKey(final UUID accountUuid, final byte deviceId) {
return "presence::{" + accountUuid.toString() + "::" + deviceId + "}";
}

View File

@ -74,7 +74,7 @@ public class PushLatencyManager {
this.clock = clock;
}
void recordPushSent(final UUID accountUuid, final long deviceId, final boolean isVoip, final boolean isUrgent) {
void recordPushSent(final UUID accountUuid, final byte deviceId, final boolean isVoip, final boolean isUrgent) {
try {
final String recordJson = SystemMapper.jsonMapper().writeValueAsString(
new PushRecord(Instant.now(clock), isVoip ? PushType.VOIP : PushType.STANDARD, Optional.of(isUrgent)));
@ -89,7 +89,7 @@ public class PushLatencyManager {
}
}
void recordQueueRead(final UUID accountUuid, final long deviceId, final String userAgentString) {
void recordQueueRead(final UUID accountUuid, final byte deviceId, final String userAgentString) {
takePushRecord(accountUuid, deviceId).thenAccept(pushRecord -> {
if (pushRecord != null) {
final Duration latency = Duration.between(pushRecord.timestamp(), Instant.now());
@ -114,7 +114,7 @@ public class PushLatencyManager {
}
@VisibleForTesting
CompletableFuture<PushRecord> takePushRecord(final UUID accountUuid, final long deviceId) {
CompletableFuture<PushRecord> takePushRecord(final UUID accountUuid, final byte deviceId) {
final String key = getFirstUnacknowledgedPushKey(accountUuid, deviceId);
return redisCluster.withCluster(connection -> {
@ -141,7 +141,7 @@ public class PushLatencyManager {
});
}
private static String getFirstUnacknowledgedPushKey(final UUID accountUuid, final long deviceId) {
private static String getFirstUnacknowledgedPushKey(final UUID accountUuid, final byte deviceId) {
return "push_latency::v2::" + accountUuid.toString() + "::" + deviceId;
}
}

View File

@ -47,7 +47,7 @@ public class PushNotificationManager {
this.pushLatencyManager = pushLatencyManager;
}
public void sendNewMessageNotification(final Account destination, final long destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
public void sendNewMessageNotification(final Account destination, final byte destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
final Device device = destination.getDevice(destinationDeviceId).orElseThrow(NotPushRegisteredException::new);
final Pair<String, PushNotification.TokenType> tokenAndType = getToken(device);

View File

@ -34,7 +34,7 @@ public class ReceiptSender {
;
}
public void sendReceipt(ServiceIdentifier sourceIdentifier, long sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) {
public void sendReceipt(ServiceIdentifier sourceIdentifier, byte sourceDeviceId, AciServiceIdentifier destinationIdentifier, long messageId) {
if (sourceIdentifier.equals(destinationIdentifier)) {
return;
}

View File

@ -223,7 +223,7 @@ public class Account {
this.devices.add(device);
}
public void removeDevice(final long deviceId) {
public void removeDevice(final byte deviceId) {
requireNotStale();
this.devices.removeIf(device -> device.getId() == deviceId);
@ -241,7 +241,7 @@ public class Account {
return getDevice(Device.PRIMARY_ID);
}
public Optional<Device> getDevice(final long deviceId) {
public Optional<Device> getDevice(final byte deviceId) {
requireNotStale();
return devices.stream().filter(device -> device.getId() == deviceId).findFirst();
@ -281,15 +281,19 @@ public class Account {
return getPrimaryDevice().map(Device::isEnabled).orElse(false);
}
public long getNextDeviceId() {
public byte getNextDeviceId() {
requireNotStale();
long candidateId = Device.PRIMARY_ID + 1;
byte candidateId = Device.PRIMARY_ID + 1;
while (getDevice(candidateId).isPresent()) {
candidateId++;
}
if (candidateId <= Device.PRIMARY_ID) {
throw new RuntimeException("device ID overflow");
}
return candidateId;
}

View File

@ -268,9 +268,9 @@ public class AccountsManager {
public Account changeNumber(final Account account,
final String targetNumber,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws InterruptedException, MismatchedDevicesException {
final String originalNumber = account.getNumber();
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
@ -369,9 +369,9 @@ public class AccountsManager {
public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey,
final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final UUID pni = account.getPhoneNumberIdentifier();
@ -395,8 +395,8 @@ public class AccountsManager {
private boolean setPniKeys(final Account account,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) {
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) {
if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
return false;
} else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) {
@ -424,9 +424,9 @@ public class AccountsManager {
}
private void validateDevices(final Account account,
@Nullable final Map<Long, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Long, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException {
@Nullable final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
@Nullable final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException {
if (pniSignedPreKeys == null && pniRegistrationIds == null) {
return;
} else if (pniSignedPreKeys == null || pniRegistrationIds == null) {
@ -580,7 +580,7 @@ public class AccountsManager {
}
/**
* Specialized version of {@link #updateDevice(Account, long, Consumer)} that minimizes potentially contentious and
* Specialized version of {@link #updateDevice(Account, byte, Consumer)} that minimizes potentially contentious and
* redundant updates of {@code device.lastSeen}
*/
public Account updateDeviceLastSeen(Account account, Device device, final long lastSeen) {
@ -741,7 +741,7 @@ public class AccountsManager {
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
public Account updateDevice(Account account, byte deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device
@ -749,7 +749,8 @@ public class AccountsManager {
});
}
public CompletableFuture<Account> updateDeviceAsync(final Account account, final long deviceId, final Consumer<Device> deviceUpdater) {
public CompletableFuture<Account> updateDeviceAsync(final Account account, final byte deviceId,
final Consumer<Device> deviceUpdater) {
return updateAsync(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device

View File

@ -43,10 +43,10 @@ public class ChangeNumberManager {
public Account changeNumber(final Account account, final String number,
@Nullable final IdentityKey pniIdentityKey,
@Nullable final Map<Long, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, KEMSignedPreKey> devicePqLastResortPreKeys,
@Nullable final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
@Nullable final List<IncomingMessage> deviceMessages,
@Nullable final Map<Long, Integer> pniRegistrationIds)
@Nullable final Map<Byte, Integer> pniRegistrationIds)
throws InterruptedException, MismatchedDevicesException, StaleDevicesException {
if (ObjectUtils.allNotNull(pniIdentityKey, deviceSignedPreKeys, deviceMessages, pniRegistrationIds)) {
@ -83,10 +83,10 @@ public class ChangeNumberManager {
public Account updatePniKeys(final Account account,
final IdentityKey pniIdentityKey,
final Map<Long, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Long, KEMSignedPreKey> devicePqLastResortPreKeys,
final Map<Byte, ECSignedPreKey> deviceSignedPreKeys,
@Nullable final Map<Byte, KEMSignedPreKey> devicePqLastResortPreKeys,
final List<IncomingMessage> deviceMessages,
final Map<Long, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
final Map<Byte, Integer> pniRegistrationIds) throws MismatchedDevicesException, StaleDevicesException {
validateDeviceMessages(account, deviceMessages);
// Don't try to be smart about ignoring unnecessary retries. If we make literally no change we will skip the ddb

View File

@ -6,11 +6,12 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
@ -19,13 +20,15 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
public class Device {
public static final long PRIMARY_ID = 1;
public static final int MAXIMUM_DEVICE_ID = 256;
public static final byte PRIMARY_ID = 1;
public static final byte MAXIMUM_DEVICE_ID = Byte.MAX_VALUE;
public static final int MAX_REGISTRATION_ID = 0x3FFF;
public static final List<Long> ALL_POSSIBLE_DEVICE_IDS = LongStream.range(1, MAXIMUM_DEVICE_ID).boxed().collect(Collectors.toList());
public static final List<Byte> ALL_POSSIBLE_DEVICE_IDS = IntStream.range(Device.PRIMARY_ID, MAXIMUM_DEVICE_ID).boxed()
.map(Integer::byteValue).collect(Collectors.toList());
@JsonDeserialize(using = DeviceIdDeserializer.class)
@JsonProperty
private long id;
private byte id;
@JsonProperty
private String name;
@ -135,11 +138,11 @@ public class Device {
}
}
public long getId() {
public byte getId() {
return id;
}
public void setId(long id) {
public void setId(byte id) {
this.id = id;
}

View File

@ -0,0 +1,41 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import java.io.IOException;
/**
* The built-in {@link com.fasterxml.jackson.databind.deser.std.NumberDeserializers.ByteDeserializer} will return
* negative values&mdash;both verbatim and by coercing 128&hellip;255. We prefer this invalid data to fail fast, so this
* is a simpler and stricter deserializer.
*/
public class DeviceIdDeserializer extends JsonDeserializer<Byte> {
@Override
public Byte deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
byte value = p.getByteValue();
if (value < Device.PRIMARY_ID) {
throw new DeviceIdDeserializationException();
}
return value;
}
static class DeviceIdDeserializationException extends IOException {
DeviceIdDeserializationException() {
super("Invalid Device ID");
}
}
}

View File

@ -42,12 +42,12 @@ public class KeysManager {
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<ECPreKey> keys) {
public CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final List<ECPreKey> keys) {
return store(identifier, deviceId, keys, null, null, null);
}
public CompletableFuture<Void> store(
final UUID identifier, final long deviceId,
final UUID identifier, final byte deviceId,
@Nullable final List<ECPreKey> ecKeys,
@Nullable final List<KEMSignedPreKey> pqKeys,
@Nullable final ECSignedPreKey ecSignedPreKey,
@ -63,7 +63,8 @@ public class KeysManager {
storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys));
}
if (ecSignedPreKey != null && dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
if (ecSignedPreKey != null && dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration()
.storeEcSignedPreKeys()) {
storeFutures.add(ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey));
}
@ -74,7 +75,7 @@ public class KeysManager {
return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
}
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Long, ECSignedPreKey> keys) {
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Byte, ECSignedPreKey> keys) {
if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
return ecSignedPreKeys.store(identifier, keys);
} else {
@ -82,27 +83,30 @@ public class KeysManager {
}
}
public CompletableFuture<Boolean> storeEcSignedPreKeyIfAbsent(final UUID identifier, final long deviceId, final ECSignedPreKey signedPreKey) {
public CompletableFuture<Boolean> storeEcSignedPreKeyIfAbsent(final UUID identifier, final byte deviceId,
final ECSignedPreKey signedPreKey) {
return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey);
}
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final Map<Long, KEMSignedPreKey> keys) {
public CompletableFuture<Void> storePqLastResort(final UUID identifier, final Map<Byte, KEMSignedPreKey> keys) {
return pqLastResortKeys.store(identifier, keys);
}
public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final long deviceId, final List<ECPreKey> preKeys) {
public CompletableFuture<Void> storeEcOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<ECPreKey> preKeys) {
return ecPreKeys.store(identifier, deviceId, preKeys);
}
public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final long deviceId, final List<KEMSignedPreKey> preKeys) {
public CompletableFuture<Void> storeKemOneTimePreKeys(final UUID identifier, final byte deviceId,
final List<KEMSignedPreKey> preKeys) {
return pqPreKeys.store(identifier, deviceId, preKeys);
}
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<ECPreKey>> takeEC(final UUID identifier, final byte deviceId) {
return ecPreKeys.take(identifier, deviceId);
}
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<KEMSignedPreKey>> takePQ(final UUID identifier, final byte deviceId) {
return pqPreKeys.take(identifier, deviceId)
.thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey
.map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey))
@ -110,26 +114,26 @@ public class KeysManager {
}
@VisibleForTesting
CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final long deviceId) {
CompletableFuture<Optional<KEMSignedPreKey>> getLastResort(final UUID identifier, final byte deviceId) {
return pqLastResortKeys.find(identifier, deviceId);
}
public CompletableFuture<Optional<ECSignedPreKey>> getEcSignedPreKey(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<ECSignedPreKey>> getEcSignedPreKey(final UUID identifier, final byte deviceId) {
return ecSignedPreKeys.find(identifier, deviceId);
}
public CompletableFuture<List<Long>> getPqEnabledDevices(final UUID identifier) {
public CompletableFuture<List<Byte>> getPqEnabledDevices(final UUID identifier) {
return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().toFuture();
}
public CompletableFuture<Integer> getEcCount(final UUID identifier, final long deviceId) {
public CompletableFuture<Integer> getEcCount(final UUID identifier, final byte deviceId) {
return ecPreKeys.getCount(identifier, deviceId);
}
public CompletableFuture<Integer> getPqCount(final UUID identifier, final long deviceId) {
public CompletableFuture<Integer> getPqCount(final UUID identifier, final byte deviceId) {
return pqPreKeys.getCount(identifier, deviceId);
}
public CompletableFuture<Void> delete(final UUID accountUuid) {
return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid),
@ -140,7 +144,7 @@ public class KeysManager {
pqLastResortKeys.delete(accountUuid));
}
public CompletableFuture<Void> delete(final UUID accountUuid, final long deviceId) {
public CompletableFuture<Void> delete(final UUID accountUuid, final byte deviceId) {
return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid, deviceId),
pqPreKeys.delete(accountUuid, deviceId),

View File

@ -137,7 +137,7 @@ public class MessagePersister implements Managed {
for (final String queue : queuesToPersist) {
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queue);
final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queue);
try {
persistQueue(accountUuid, deviceId);
@ -161,7 +161,7 @@ public class MessagePersister implements Managed {
}
@VisibleForTesting
void persistQueue(final UUID accountUuid, final long deviceId) throws MessagePersistenceException {
void persistQueue(final UUID accountUuid, final byte deviceId) throws MessagePersistenceException {
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(accountUuid);
if (maybeAccount.isEmpty()) {

View File

@ -155,7 +155,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
}
public long insert(final UUID guid, final UUID destinationUuid, final long destinationDevice,
public long insert(final UUID guid, final UUID destinationUuid, final byte destinationDevice,
final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
return (long) insertTimer.record(() ->
@ -168,7 +168,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
public CompletableFuture<Optional<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final long destinationDevice,
final byte destinationDevice,
final UUID messageGuid) {
return remove(destinationUuid, destinationDevice, List.of(messageGuid))
@ -177,7 +177,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
@SuppressWarnings("unchecked")
public CompletableFuture<List<MessageProtos.Envelope>> remove(final UUID destinationUuid,
final long destinationDevice,
final byte destinationDevice,
final List<UUID> messageGuids) {
return removeByGuidScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, destinationDevice),
@ -202,12 +202,12 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}, messageDeletionExecutorService);
}
public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) {
public boolean hasMessages(final UUID destinationUuid, final byte destinationDevice) {
return readDeleteCluster.withBinaryCluster(
connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0);
}
public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDevice) {
public Publisher<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDevice) {
final long earliestAllowableEphemeralTimestamp =
clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
@ -238,7 +238,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return message.hasEphemeral() && message.getEphemeral() && message.getTimestamp() < earliestAllowableTimestamp;
}
private void discardStaleEphemeralMessages(final UUID destinationUuid, final long destinationDevice,
private void discardStaleEphemeralMessages(final UUID destinationUuid, final byte destinationDevice,
Flux<MessageProtos.Envelope> staleEphemeralMessages) {
staleEphemeralMessages
.map(e -> UUID.fromString(e.getServerGuid()))
@ -251,7 +251,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
@VisibleForTesting
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final long destinationDevice) {
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final byte destinationDevice) {
// fetch messages by page
return getNextMessagePage(destinationUuid, destinationDevice, -1)
@ -284,7 +284,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
});
}
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final long destinationDevice,
private Flux<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
long messageId) {
return getItemsScript.executeBinaryReactive(
@ -315,7 +315,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
@VisibleForTesting
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final long destinationDevice,
List<MessageProtos.Envelope> getMessagesToPersist(final UUID accountUuid, final byte destinationDevice,
final int limit) {
return getMessagesTimer.record(() -> {
final List<ScoredValue<byte[]>> scoredMessages = readDeleteCluster.withBinaryCluster(
@ -336,16 +336,14 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
public CompletableFuture<Void> clear(final UUID destinationUuid) {
final CompletableFuture<?>[] clearFutures = new CompletableFuture[Device.MAXIMUM_DEVICE_ID];
for (int deviceId = 0; deviceId < Device.MAXIMUM_DEVICE_ID; deviceId++) {
clearFutures[deviceId] = clear(destinationUuid, deviceId);
}
return CompletableFuture.allOf(clearFutures);
return CompletableFuture.allOf(
Device.ALL_POSSIBLE_DEVICE_IDS.stream()
.map(deviceId -> clear(destinationUuid, deviceId))
.toList()
.toArray(CompletableFuture[]::new));
}
public CompletableFuture<Void> clear(final UUID destinationUuid, final long deviceId) {
public CompletableFuture<Void> clear(final UUID destinationUuid, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return removeQueueScript.executeBinaryAsync(List.of(getMessageQueueKey(destinationUuid, deviceId),
@ -368,23 +366,23 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
String.valueOf(limit))));
}
void addQueueToPersist(final UUID accountUuid, final long deviceId) {
void addQueueToPersist(final UUID accountUuid, final byte deviceId) {
readDeleteCluster.useBinaryCluster(connection -> connection.sync()
.zadd(getQueueIndexKey(accountUuid, deviceId), ZAddArgs.Builder.nx(), System.currentTimeMillis(),
getMessageQueueKey(accountUuid, deviceId)));
}
void lockQueueForPersistence(final UUID accountUuid, final long deviceId) {
void lockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
readDeleteCluster.useBinaryCluster(
connection -> connection.sync().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE));
}
void unlockQueueForPersistence(final UUID accountUuid, final long deviceId) {
void unlockQueueForPersistence(final UUID accountUuid, final byte deviceId) {
readDeleteCluster.useBinaryCluster(
connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId)));
}
public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId,
public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId,
final MessageAvailabilityListener listener) {
final String queueName = getQueueName(destinationUuid, deviceId);
@ -500,7 +498,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
@VisibleForTesting
static String getQueueName(final UUID accountUuid, final long deviceId) {
static String getQueueName(final UUID accountUuid, final byte deviceId) {
return accountUuid + "::" + deviceId;
}
@ -513,15 +511,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
}
@VisibleForTesting
static byte[] getMessageQueueKey(final UUID accountUuid, final long deviceId) {
static byte[] getMessageQueueKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final long deviceId) {
private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getQueueIndexKey(final UUID accountUuid, final long deviceId) {
private static byte[] getQueueIndexKey(final UUID accountUuid, final byte deviceId) {
return getQueueIndexKey(SlotHash.getSlot(accountUuid.toString() + "::" + deviceId));
}
@ -529,7 +527,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getPersistInProgressKey(final UUID accountUuid, final long deviceId) {
private static byte[] getPersistInProgressKey(final UUID accountUuid, final byte deviceId) {
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
@ -539,7 +537,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
return UUID.fromString(queueName.substring(startOfHashTag + 1, queueName.indexOf("::", startOfHashTag)));
}
static long getDeviceIdFromQueueName(final String queueName) {
return Long.parseLong(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
static byte getDeviceIdFromQueueName(final String queueName) {
return Byte.parseByte(queueName.substring(queueName.lastIndexOf("::") + 2, queueName.lastIndexOf('}')));
}
}

View File

@ -83,11 +83,13 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
this.messageDeletionScheduler = Schedulers.fromExecutor(messageDeletionExecutor);
}
public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) {
public void store(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid,
final byte destinationDeviceId) {
storeTimer.record(() -> writeInBatches(messages, (messageBatch) -> storeBatch(messageBatch, destinationAccountUuid, destinationDeviceId)));
}
private void storeBatch(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid, final long destinationDeviceId) {
private void storeBatch(final List<MessageProtos.Envelope> messages, final UUID destinationAccountUuid,
final byte destinationDeviceId) {
if (messages.size() > DYNAMO_DB_MAX_BATCH_SIZE) {
throw new IllegalArgumentException("Maximum batch size of " + DYNAMO_DB_MAX_BATCH_SIZE + " exceeded with " + messages.size() + " messages");
}
@ -112,7 +114,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
executeTableWriteItemsUntilComplete(Map.of(tableName, writeItems));
}
public Publisher<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final long destinationDeviceId,
public Publisher<MessageProtos.Envelope> load(final UUID destinationAccountUuid, final byte destinationDeviceId,
final Integer limit) {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@ -191,7 +193,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
}
public CompletableFuture<Optional<MessageProtos.Envelope>> deleteMessage(final UUID destinationAccountUuid,
final long destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
final byte destinationDeviceId, final UUID messageUuid, final long serverTimestamp) {
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
final AttributeValue sortKey = convertSortKey(destinationDeviceId, serverTimestamp, messageUuid);
@ -240,7 +242,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
.toFuture();
}
public CompletableFuture<Void> deleteAllMessagesForDevice(final UUID destinationAccountUuid, final long destinationDeviceId) {
public CompletableFuture<Void> deleteAllMessagesForDevice(final UUID destinationAccountUuid,
final byte destinationDeviceId) {
final Timer.Sample sample = Timer.start();
final AttributeValue partitionKey = convertPartitionKey(destinationAccountUuid);
@ -284,8 +287,10 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
return AttributeValues.fromUUID(destinationAccountUuid);
}
private static AttributeValue convertSortKey(final long destinationDeviceId, final long serverTimestamp, final UUID messageUuid) {
private static AttributeValue convertSortKey(final byte destinationDeviceId, final long serverTimestamp,
final UUID messageUuid) {
ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[32]);
// for compatibility - destinationDeviceId was previously `long`
byteBuffer.putLong(destinationDeviceId);
byteBuffer.putLong(serverTimestamp);
byteBuffer.putLong(messageUuid.getMostSignificantBits());
@ -293,8 +298,9 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final long destinationDeviceId) {
private static AttributeValue convertDestinationDeviceIdToSortKeyPrefix(final byte destinationDeviceId) {
ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
// for compatibility - destinationDeviceId was previously `long`
byteBuffer.putLong(destinationDeviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}

View File

@ -60,7 +60,7 @@ public class MessagesManager {
this.messageDeletionExecutor = messageDeletionExecutor;
}
public void insert(UUID destinationUuid, long destinationDevice, Envelope message) {
public void insert(UUID destinationUuid, byte destinationDevice, Envelope message) {
final UUID messageGuid = UUID.randomUUID();
messagesCache.insert(messageGuid, destinationUuid, destinationDevice, message);
@ -70,11 +70,11 @@ public class MessagesManager {
}
}
public boolean hasCachedMessages(final UUID destinationUuid, final long destinationDevice) {
public boolean hasCachedMessages(final UUID destinationUuid, final byte destinationDevice) {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public Mono<Pair<List<Envelope>, Boolean>> getMessagesForDevice(UUID destinationUuid, long destinationDevice,
public Mono<Pair<List<Envelope>, Boolean>> getMessagesForDevice(UUID destinationUuid, byte destinationDevice,
boolean cachedMessagesOnly) {
return Flux.from(
@ -84,13 +84,13 @@ public class MessagesManager {
.map(envelopes -> new Pair<>(envelopes, envelopes.size() >= RESULT_SET_CHUNK_SIZE));
}
public Publisher<Envelope> getMessagesForDeviceReactive(UUID destinationUuid, long destinationDevice,
public Publisher<Envelope> getMessagesForDeviceReactive(UUID destinationUuid, byte destinationDevice,
final boolean cachedMessagesOnly) {
return getMessagesForDevice(destinationUuid, destinationDevice, null, cachedMessagesOnly);
}
private Publisher<Envelope> getMessagesForDevice(UUID destinationUuid, long destinationDevice,
private Publisher<Envelope> getMessagesForDevice(UUID destinationUuid, byte destinationDevice,
@Nullable Integer limit, final boolean cachedMessagesOnly) {
final Publisher<Envelope> dynamoPublisher =
@ -108,13 +108,13 @@ public class MessagesManager {
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid));
}
public CompletableFuture<Void> clear(UUID destinationUuid, long deviceId) {
public CompletableFuture<Void> clear(UUID destinationUuid, byte deviceId) {
return CompletableFuture.allOf(
messagesCache.clear(destinationUuid, deviceId),
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId));
}
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, long destinationDeviceId, UUID guid,
public CompletableFuture<Optional<Envelope>> delete(UUID destinationUuid, byte destinationDeviceId, UUID guid,
@Nullable Long serverTimestamp) {
return messagesCache.remove(destinationUuid, destinationDeviceId, guid)
.thenComposeAsync(removed -> {
@ -140,7 +140,7 @@ public class MessagesManager {
*/
public int persistMessages(
final UUID destinationUuid,
final long destinationDeviceId,
final byte destinationDeviceId,
final List<Envelope> messages) {
final List<Envelope> nonEphemeralMessages = messages.stream()
@ -165,7 +165,7 @@ public class MessagesManager {
public void addMessageAvailabilityListener(
final UUID destinationUuid,
final long destinationDeviceId,
final byte destinationDeviceId,
final MessageAvailabilityListener listener) {
messagesCache.addMessageAvailabilityListener(destinationUuid, destinationDeviceId, listener);
}

View File

@ -14,7 +14,7 @@ public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, long deviceId, AccountsManager accountsManager) {
public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));

View File

@ -31,7 +31,7 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final ECSignedPreKey signedPreKey) {
protected Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final byte deviceId, final ECSignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),
@ -54,7 +54,7 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore
}
}
public CompletableFuture<Boolean> storeIfAbsent(final UUID identifier, final long deviceId, final ECSignedPreKey signedPreKey) {
public CompletableFuture<Boolean> storeIfAbsent(final UUID identifier, final byte deviceId, final ECSignedPreKey signedPreKey) {
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, signedPreKey))

View File

@ -21,7 +21,7 @@ public class RepeatedUseKEMSignedPreKeyStore extends RepeatedUseSignedPreKeyStor
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final KEMSignedPreKey signedPreKey) {
protected Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final byte deviceId, final KEMSignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(accountUuid),

View File

@ -67,7 +67,7 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
*
* @return a future that completes once the key has been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final K signedPreKey) {
public CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final K signedPreKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
@ -87,13 +87,13 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
*
* @return a future that completes once all keys have been stored
*/
public CompletableFuture<Void> store(final UUID identifier, final Map<Long, K> signedPreKeysByDeviceId) {
public CompletableFuture<Void> store(final UUID identifier, final Map<Byte, K> signedPreKeysByDeviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder()
.transactItems(signedPreKeysByDeviceId.entrySet().stream()
.map(entry -> {
final long deviceId = entry.getKey();
final byte deviceId = entry.getKey();
final K signedPreKey = entry.getValue();
return TransactWriteItem.builder()
@ -117,7 +117,7 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
* @return a future that yields an optional signed pre-key if one is available for the target device or empty if no
* key could be found for the target device
*/
public CompletableFuture<Optional<K>> find(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<K>> find(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
final CompletableFuture<Optional<K>> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder()
@ -165,7 +165,7 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
*
* @return a future that completes once the repeated-use pre-key has been removed from the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
public CompletableFuture<Void> delete(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder()
@ -175,7 +175,7 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
.thenRun(() -> sample.stop(deleteForDeviceTimer));
}
public Flux<Long> getDeviceIdsWithKeys(final UUID identifier) {
public Flux<Byte> getDeviceIdsWithKeys(final UUID identifier) {
return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
.tableName(tableName)
.keyConditionExpression("#uuid = :uuid")
@ -186,10 +186,10 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
.consistentRead(true)
.build())
.items())
.map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n()));
.map(item -> Byte.parseByte(item.get(KEY_DEVICE_ID).n()));
}
protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final long deviceId) {
protected static Map<String, AttributeValue> getPrimaryKey(final UUID identifier, final byte deviceId) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID, getSortKey(deviceId));
@ -199,11 +199,12 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
return AttributeValues.fromUUID(accountUuid);
}
protected static AttributeValue getSortKey(final long deviceId) {
return AttributeValues.fromLong(deviceId);
protected static AttributeValue getSortKey(final byte deviceId) {
return AttributeValues.fromInt(deviceId);
}
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final long deviceId, final K signedPreKey);
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID accountUuid, final byte deviceId,
final K signedPreKey);
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);
}

View File

@ -24,7 +24,7 @@ public class SingleUseECPreKeyStore extends SingleUsePreKeyStore<ECPreKey> {
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final ECPreKey preKey) {
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId, final ECPreKey preKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()),

View File

@ -21,7 +21,7 @@ public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore<KEMSignedPreKe
}
@Override
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId, final KEMSignedPreKey signedPreKey) {
protected Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId, final KEMSignedPreKey signedPreKey) {
return Map.of(
KEY_ACCOUNT_UUID, getPartitionKey(identifier),
KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.keyId()),

View File

@ -36,11 +36,11 @@ import software.amazon.awssdk.services.dynamodb.model.Select;
/**
* A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key
* store's {@link #take(UUID, long)} method are guaranteed to be returned exactly once, and repeated calls will never
* store's {@link #take(UUID, byte)} method are guaranteed to be returned exactly once, and repeated calls will never
* yield the same key.
* <p/>
* Each {@link Account} may have one or more {@link Device devices}. Clients <em>should</em> regularly check their
* supply of single-use pre-keys (see {@link #getCount(UUID, long)}) and upload new keys when their supply runs low. In
* supply of single-use pre-keys (see {@link #getCount(UUID, byte)}) and upload new keys when their supply runs low. In
* the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party
* may fall back to using the device's repeated-use ("last-resort") signed pre-key instead.
*/
@ -91,7 +91,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
* @return a future that completes when all previously-stored keys have been removed and the given collection of
* pre-keys has been stored in its place
*/
public CompletableFuture<Void> store(final UUID identifier, final long deviceId, final List<K> preKeys) {
public CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final List<K> preKeys) {
final Timer.Sample sample = Timer.start();
return Mono.fromFuture(() -> delete(identifier, deviceId))
@ -103,7 +103,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
.thenRun(() -> sample.stop(storeKeyBatchTimer));
}
private CompletableFuture<Void> store(final UUID identifier, final long deviceId, final K preKey) {
private CompletableFuture<Void> store(final UUID identifier, final byte deviceId, final K preKey) {
final Timer.Sample sample = Timer.start();
return dynamoDbAsyncClient.putItem(PutItemRequest.builder()
@ -124,7 +124,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
* @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are
* available for the target device
*/
public CompletableFuture<Optional<K>> take(final UUID identifier, final long deviceId) {
public CompletableFuture<Optional<K>> take(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
final AttributeValue partitionKey = getPartitionKey(identifier);
final AtomicInteger keysConsidered = new AtomicInteger(0);
@ -169,7 +169,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
* @return a future that yields the approximate number of single-use pre-keys currently available for the target
* device
*/
public CompletableFuture<Integer> getCount(final UUID identifier, final long deviceId) {
public CompletableFuture<Integer> getCount(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
// Getting an accurate count from DynamoDB can be very confusing. See:
@ -230,7 +230,7 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
* @return a future that completes when all single-use pre-keys have been removed for the target device
*/
public CompletableFuture<Void> delete(final UUID identifier, final long deviceId) {
public CompletableFuture<Void> delete(final UUID identifier, final byte deviceId) {
final Timer.Sample sample = Timer.start();
return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder()
@ -267,20 +267,20 @@ public abstract class SingleUsePreKeyStore<K extends PreKey<?>> {
return AttributeValues.fromUUID(accountUuid);
}
protected static AttributeValue getSortKey(final long deviceId, final long keyId) {
protected static AttributeValue getSortKey(final byte deviceId, final long keyId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
byteBuffer.putLong(deviceId);
byteBuffer.putLong(keyId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private static AttributeValue getSortKeyPrefix(final long deviceId) {
private static AttributeValue getSortKeyPrefix(final byte deviceId) {
final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]);
byteBuffer.putLong(deviceId);
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final long deviceId,
protected abstract Map<String, AttributeValue> getItemFromPreKey(final UUID identifier, final byte deviceId,
final K preKey);
protected abstract K getPreKeyFromItem(final Map<String, AttributeValue> item);

View File

@ -8,7 +8,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
@ -24,7 +23,7 @@ public class DestinationDeviceValidator {
* @see #validateRegistrationIds(Account, Stream, boolean)
*/
public static <T> void validateRegistrationIds(final Account account, final Collection<T> messages,
Function<T, Long> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity)
Function<T, Byte> getDeviceId, Function<T, Integer> getRegistrationId, boolean usePhoneNumberIdentity)
throws StaleDevicesException {
validateRegistrationIds(account,
messages.stream().map(m -> new Pair<>(getDeviceId.apply(m), getRegistrationId.apply(m))),
@ -47,13 +46,13 @@ public class DestinationDeviceValidator {
* account does not have a corresponding device or if the registration IDs do not match
*/
public static void validateRegistrationIds(final Account account,
final Stream<Pair<Long, Integer>> deviceIdAndRegistrationIdStream,
final Stream<Pair<Byte, Integer>> deviceIdAndRegistrationIdStream,
final boolean usePhoneNumberIdentity) throws StaleDevicesException {
final List<Long> staleDevices = deviceIdAndRegistrationIdStream
final List<Byte> staleDevices = deviceIdAndRegistrationIdStream
.filter(deviceIdAndRegistrationId -> deviceIdAndRegistrationId.second() > 0)
.filter(deviceIdAndRegistrationId -> {
final long deviceId = deviceIdAndRegistrationId.first();
final byte deviceId = deviceIdAndRegistrationId.first();
final int registrationId = deviceIdAndRegistrationId.second();
boolean registrationIdMatches = account.getDevice(deviceId)
.map(device -> registrationId == (usePhoneNumberIdentity
@ -86,19 +85,19 @@ public class DestinationDeviceValidator {
* account
*/
public static void validateCompleteDeviceList(final Account account,
final Set<Long> messageDeviceIds,
final Set<Long> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Byte> messageDeviceIds,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {
final Set<Long> accountDeviceIds = account.getDevices().stream()
final Set<Byte> accountDeviceIds = account.getDevices().stream()
.filter(Device::isEnabled)
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
final Set<Long> missingDeviceIds = new HashSet<>(accountDeviceIds);
final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);
final Set<Long> extraDeviceIds = new HashSet<>(messageDeviceIds);
final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);
if (!missingDeviceIds.isEmpty() || !extraDeviceIds.isEmpty()) {

View File

@ -10,7 +10,7 @@ import java.util.Base64;
public class ProvisioningAddress extends WebsocketAddress {
public ProvisioningAddress(String address, int id) {
public ProvisioningAddress(String address, byte id) {
super(address, id);
}
@ -26,6 +26,6 @@ public class ProvisioningAddress extends WebsocketAddress {
byte[] random = new byte[16];
new SecureRandom().nextBytes(random);
return new ProvisioningAddress(Base64.getUrlEncoder().withoutPadding().encodeToString(random), 0);
return new ProvisioningAddress(Base64.getUrlEncoder().withoutPadding().encodeToString(random), (byte) 0);
}
}

View File

@ -10,9 +10,9 @@ import org.whispersystems.textsecuregcm.storage.PubSubAddress;
public class WebsocketAddress implements PubSubAddress {
private final String number;
private final long deviceId;
private final byte deviceId;
public WebsocketAddress(String number, long deviceId) {
public WebsocketAddress(String number, byte deviceId) {
this.number = number;
this.deviceId = deviceId;
}
@ -26,7 +26,7 @@ public class WebsocketAddress implements PubSubAddress {
}
this.number = parts[0];
this.deviceId = Long.parseLong(parts[1]);
this.deviceId = Byte.parseByte(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
@ -36,7 +36,7 @@ public class WebsocketAddress implements PubSubAddress {
return number;
}
public long getDeviceId() {
public byte getDeviceId() {
return deviceId;
}

View File

@ -41,7 +41,7 @@ public class MigrateSignedECPreKeysCommand extends AbstractSinglePassCrawlAccoun
accounts.flatMap(account -> Flux.fromIterable(account.getDevices())
.flatMap(device -> {
final List<Tuple3<UUID, Long, ECSignedPreKey>> keys = new ArrayList<>(2);
final List<Tuple3<UUID, Byte, ECSignedPreKey>> keys = new ArrayList<>(2);
if (device.getSignedPreKey(IdentityType.ACI) != null) {
keys.add(Tuples.of(account.getUuid(), device.getId(), device.getSignedPreKey(IdentityType.ACI)));

View File

@ -36,7 +36,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
subparser.addArgument("-d", "--deviceId")
.dest("deviceIds")
.type(Long.class)
.type(Byte.class)
.action(Arguments.append())
.required(true);
@ -57,7 +57,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
commandStopListener.start();
final UUID aci = UUID.fromString(namespace.getString("uuid").trim());
final List<Long> deviceIds = namespace.getList("deviceIds");
final List<Byte> deviceIds = namespace.getList("deviceIds");
final CommandDependencies deps = CommandDependencies.build("unlink-device", environment, configuration);
@ -68,7 +68,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
throw new IllegalArgumentException("cannot delete primary device");
}
for (long deviceId : deviceIds) {
for (byte deviceId : deviceIds) {
/** see {@link org.whispersystems.textsecuregcm.controllers.DeviceController#removeDevice} */
System.out.format("Removing device %s::%d\n", aci, deviceId);
account = deps.accountsManager().update(account, a -> a.removeDevice(deviceId));

View File

@ -55,7 +55,7 @@ message GetDevicesResponse {
/**
* The identifier for the device within an account.
*/
uint64 id = 1;
uint32 id = 1;
/**
* A sequence of bytes that encodes an encrypted human-readable name for
@ -86,7 +86,7 @@ message RemoveDeviceRequest {
/**
* The identifier for the device to remove from the authenticated account.
*/
uint64 id = 1;
uint32 id = 1;
}
message SetDeviceNameRequest {

View File

@ -154,7 +154,7 @@ message GetPreKeysRequest {
* retrieve pre-keys. If not set, pre-keys are returned for all devices
* associated with the targeted account.
*/
uint64 device_id = 2;
uint32 device_id = 2;
}
message GetPreKeysAnonymousRequest {
@ -199,7 +199,7 @@ message GetPreKeysResponse {
/**
* A map of device IDs to pre-key "bundles" for the targeted account.
*/
map<uint64, PreKeyBundle> pre_keys = 2;
map<uint32, PreKeyBundle> pre_keys = 2;
}
message SetOneTimeEcPreKeysRequest {
@ -276,4 +276,3 @@ message CheckIdentityKeyResponse {
*/
bytes identity_key = 2;
}

View File

@ -43,7 +43,7 @@ import java.util.Set;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
@ -89,7 +89,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
@ -126,7 +126,8 @@ class AuthEnablementRefreshRequirementProviderTest {
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
IntStream.range(2, 4)
.forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId)));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
@ -137,22 +138,22 @@ class AuthEnablementRefreshRequirementProviderTest {
@Test
void testBuildDevicesEnabled() {
final long disabledDeviceId = 3L;
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
LongStream.range(1, 5)
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(id);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Long, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
@ -168,7 +169,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();
assert account.getPrimaryDevice().orElseThrow().isEnabled();
@ -199,13 +200,16 @@ class AuthEnablementRefreshRequirementProviderTest {
}
static Stream<Arguments> testDeviceEnabledChanged() {
final byte deviceId1 = Device.PRIMARY_ID;
final byte deviceId2 = 2;
final byte deviceId3 = 3;
return Stream.of(
Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, false, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, false), Map.of(2L, true, 3L, true))
Arguments.of(Map.of(deviceId1, false, deviceId2, false), Map.of(deviceId1, true, deviceId2, false)),
Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, false, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, false), Map.of(deviceId2, true, deviceId3, true))
);
}
@ -227,9 +231,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), 3);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 1);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 2);
verify(clientPresenceManager).disconnectPresence(account.getUuid(), (byte) 3);
}
@ParameterizedTest
@ -237,13 +241,13 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
final List<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Long> deletedDeviceIds = account.getDevices().stream()
final List<Byte> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != 1L)
.filter(deviceId -> deviceId != Device.PRIMARY_ID)
.limit(removedDeviceCount)
.collect(Collectors.toList());
.toList();
assert deletedDeviceIds.size() == removedDeviceCount;
@ -269,9 +273,9 @@ class AuthEnablementRefreshRequirementProviderTest {
void testPrimaryDeviceDisabledAndDeviceRemoved() {
assert account.getPrimaryDevice().orElseThrow().isEnabled();
final Set<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final Set<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final long deletedDeviceId = 2L;
final byte deletedDeviceId = 2;
assertTrue(initialDeviceIds.remove(deletedDeviceId));
final Response response = resources.getJerseyTest()
@ -427,11 +431,11 @@ class AuthEnablementRefreshRequirementProviderTest {
@POST
@Path("/account/devices/enabled")
@ChangesDeviceEnabledState
public String setEnabled(@Auth TestPrincipal principal, Map<Long, Boolean> deviceIdsEnabled) {
public String setEnabled(@Auth TestPrincipal principal, Map<Byte, Boolean> deviceIdsEnabled) {
final StringBuilder response = new StringBuilder();
for (Entry<Long, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
for (Entry<Byte, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow();
DevicesHelper.setEnabled(device, deviceIdEnabled.getValue());
@ -462,7 +466,7 @@ class AuthEnablementRefreshRequirementProviderTest {
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Long::valueOf)
.map(Byte::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
@ -471,7 +475,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@POST
@Path("/account/disablePrimaryDeviceAndDeleteDevice/{deviceId}")
@ChangesDeviceEnabledState
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) {
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") byte deviceId) {
DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice().orElseThrow(), false);

View File

@ -150,7 +150,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticate() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@ -180,7 +180,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateNonDefaultDevice() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 2;
final byte deviceId = 2;
final String password = "12345";
final Account account = mock(Account.class);
@ -214,7 +214,7 @@ class BaseAccountAuthenticatorTest {
@CartesianTest.Values(booleans = {true, false}) final boolean deviceEnabled,
@CartesianTest.Values(booleans = {true, false}) final boolean authenticatedDeviceIsPrimary) {
final UUID uuid = UUID.randomUUID();
final long deviceId = authenticatedDeviceIsPrimary ? 1 : 2;
final byte deviceId = (byte) (authenticatedDeviceIsPrimary ? 1 : 2);
final String password = "12345";
final Account account = mock(Account.class);
@ -253,7 +253,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateV1() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@ -290,7 +290,7 @@ class BaseAccountAuthenticatorTest {
@Test
void testAuthenticateDeviceNotFound() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@ -312,13 +312,13 @@ class BaseAccountAuthenticatorTest {
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid + "." + (deviceId + 1), password), true);
assertThat(maybeAuthenticatedAccount).isEmpty();
verify(account).getDevice(deviceId + 1);
verify(account).getDevice((byte) (deviceId + 1));
}
@Test
void testAuthenticateIncorrectPassword() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
@ -365,8 +365,9 @@ class BaseAccountAuthenticatorTest {
@ParameterizedTest
@MethodSource
void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier, final long expectedDeviceId) {
final Pair<String, Long> identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username);
void testGetIdentifierAndDeviceId(final String username, final String expectedIdentifier,
final byte expectedDeviceId) {
final Pair<String, Byte> identifierAndDeviceId = BaseAccountAuthenticator.getIdentifierAndDeviceId(username);
assertEquals(expectedIdentifier, identifierAndDeviceId.first());
assertEquals(expectedDeviceId, identifierAndDeviceId.second());
@ -376,7 +377,7 @@ class BaseAccountAuthenticatorTest {
return Stream.of(
Arguments.of("", "", Device.PRIMARY_ID),
Arguments.of("test", "test", Device.PRIMARY_ID),
Arguments.of("test.7", "test", 7));
Arguments.of("test.7", "test", (byte) 7));
}
@ParameterizedTest

View File

@ -34,11 +34,11 @@ class CertificateGeneratorTest {
final CertificateGenerator certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(SIGNING_CERTIFICATE), Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn(4L);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn((byte) 4);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
}
}

View File

@ -32,7 +32,7 @@ class OptionalAccessTest {
void testUnidentifiedMissingTargetDevice() {
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq(10))).thenReturn(Optional.empty());
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try {
@ -46,7 +46,7 @@ class OptionalAccessTest {
void testUnidentifiedBadTargetDevice() {
Account account = mock(Account.class);
when(account.isEnabled()).thenReturn(true);
when(account.getDevice(eq(10))).thenReturn(Optional.empty());
when(account.getDevice(eq((byte) 10))).thenReturn(Optional.empty());
when(account.getUnidentifiedAccessKey()).thenReturn(Optional.of("1234".getBytes()));
try {

View File

@ -18,9 +18,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
public class MockAuthenticationInterceptor implements ServerInterceptor {
@Nullable
private Pair<UUID, Long> authenticatedDevice;
private Pair<UUID, Byte> authenticatedDevice;
public void setAuthenticatedDevice(final UUID accountIdentifier, final long deviceId) {
public void setAuthenticatedDevice(final UUID accountIdentifier, final byte deviceId) {
authenticatedDevice = new Pair<>(accountIdentifier, deviceId);
}

View File

@ -10,8 +10,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
@ -299,7 +299,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@Test
@ -328,7 +328,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second"));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@Test
@ -344,7 +344,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null);
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyByte(), any());
}
@ParameterizedTest

View File

@ -160,7 +160,7 @@ class AccountControllerV2Test {
}
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@ -481,7 +481,7 @@ class AccountControllerV2Test {
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@ -661,7 +661,7 @@ class AccountControllerV2Test {
assertEquals(account.isUnrestrictedUnidentifiedAccess(),
structuredResponse.data().account().allowSealedSenderFromAnyone());
final Set<Long> deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final Set<Byte> deviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
// all devices should be present
structuredResponse.data().devices().forEach(deviceDataReport -> {
@ -704,8 +704,8 @@ class AccountControllerV2Test {
buildTestAccountForDataReport(UUID.randomUUID(), exampleNumber1,
true, true,
Collections.emptyList(),
List.of(new DeviceData(1, account1Device1LastSeen, account1Device1Created, null),
new DeviceData(2, account1Device2LastSeen, account1Device2Created, "OWP"))),
List.of(new DeviceData(Device.PRIMARY_ID, account1Device1LastSeen, account1Device1Created, null),
new DeviceData((byte) 2, account1Device2LastSeen, account1Device2Created, "OWP"))),
String.format("""
# Account
Phone number: %s
@ -730,7 +730,7 @@ class AccountControllerV2Test {
buildTestAccountForDataReport(UUID.randomUUID(), account2PhoneNumber,
false, true,
List.of(new AccountBadge("badge_a", badgeAExpiration, true)),
List.of(new DeviceData(1, account2Device1LastSeen, account2Device1Created, "OWI"))),
List.of(new DeviceData(Device.PRIMARY_ID, account2Device1LastSeen, account2Device1Created, "OWI"))),
String.format("""
# Account
Phone number: %s
@ -756,7 +756,7 @@ class AccountControllerV2Test {
List.of(
new AccountBadge("badge_b", badgeBExpiration, true),
new AccountBadge("badge_c", badgeCExpiration, false)),
List.of(new DeviceData(1, account3Device1LastSeen, account3Device1Created, "OWA"))),
List.of(new DeviceData(Device.PRIMARY_ID, account3Device1LastSeen, account3Device1Created, "OWA"))),
String.format("""
# Account
Phone number: %s
@ -825,7 +825,7 @@ class AccountControllerV2Test {
return account;
}
private record DeviceData(long id, Instant lastSeen, Instant created, @Nullable String userAgent) {
private record DeviceData(byte id, Instant lastSeen, Instant created, @Nullable String userAgent) {
}

View File

@ -8,7 +8,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
@ -99,6 +99,8 @@ class DeviceControllerTest {
private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static TestClock testClock = TestClock.now();
private static final byte NEXT_DEVICE_ID = 42;
private static DeviceController deviceController = new DeviceController(
generateLinkDeviceSecret(),
accountsManager,
@ -137,9 +139,9 @@ class DeviceControllerTest {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(1L);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(account.getNextDeviceId()).thenReturn(42L);
when(account.getNextDeviceId()).thenReturn(NEXT_DEVICE_ID);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
@ -154,9 +156,9 @@ class DeviceControllerTest {
AccountsHelper.setupMockUpdate(accountsManager);
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
}
@AfterEach
@ -199,9 +201,9 @@ class DeviceControllerTest {
MediaType.APPLICATION_JSON_TYPE),
DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(commands).set(anyString(), anyString(), any());
}
@ -315,7 +317,7 @@ class DeviceControllerTest {
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE), DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L);
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class);
verify(account).addDevice(deviceCaptor.capture());
@ -335,7 +337,7 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey.get()));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey.get()));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get()));
@ -751,7 +753,7 @@ class DeviceControllerTest {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
final long deviceId = 2;
final byte deviceId = 2;
final Response response = resources
.getJerseyTest()
@ -785,10 +787,10 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(403);
verify(messagesManager, never()).clear(any(), anyLong());
verify(messagesManager, never()).clear(any(), anyByte());
verify(accountsManager, never()).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyLong());
verify(keysManager, never()).delete(any(), anyLong());
verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyByte());
verify(keysManager, never()).delete(any(), anyByte());
}
}

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.clearInvocations;
@ -84,6 +84,11 @@ class KeysControllerTest {
private static final UUID NOT_EXISTS_UUID = UUID.randomUUID();
private static final byte SAMPLE_DEVICE_ID = 1;
private static final byte SAMPLE_DEVICE_ID2 = 2;
private static final byte SAMPLE_DEVICE_ID3 = 3;
private static final byte SAMPLE_DEVICE_ID4 = 4;
private static final int SAMPLE_REGISTRATION_ID = 999;
private static final int SAMPLE_REGISTRATION_ID2 = 1002;
private static final int SAMPLE_REGISTRATION_ID4 = 1555;
@ -180,6 +185,11 @@ class KeysControllerTest {
final List<Device> allDevices = List.of(sampleDevice, sampleDevice2, sampleDevice3, sampleDevice4);
final byte sampleDeviceId = 1;
final byte sampleDevice2Id = 2;
final byte sampleDevice3Id = 3;
final byte sampleDevice4Id = 4;
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
@ -199,18 +209,18 @@ class KeysControllerTest {
when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2);
when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3);
when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null);
when(sampleDevice.getId()).thenReturn(1L);
when(sampleDevice2.getId()).thenReturn(2L);
when(sampleDevice3.getId()).thenReturn(3L);
when(sampleDevice4.getId()).thenReturn(4L);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
when(sampleDevice4.getId()).thenReturn(sampleDevice4Id);
when(existsAccount.getUuid()).thenReturn(EXISTS_UUID);
when(existsAccount.getPhoneNumberIdentifier()).thenReturn(EXISTS_PNI);
when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.empty());
when(existsAccount.getDevice(sampleDeviceId)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(sampleDevice2Id)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(sampleDevice3Id)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(sampleDevice4Id)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice((byte) 22)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
@ -225,17 +235,21 @@ class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.store(any(), anyLong(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.store(any(), anyByte(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takeEC(EXISTS_PNI, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY_PNI)));
when(KEYS.takePQ(EXISTS_PNI, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY_PNI)));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, 1)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));
when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
@ -267,8 +281,8 @@ class KeysControllerTest {
assertThat(result.getCount()).isEqualTo(5);
assertThat(result.getPqCount()).isEqualTo(5);
verify(KEYS).getEcCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, 1);
verify(KEYS).getEcCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getPqCount(AuthHelper.VALID_UUID, SAMPLE_DEVICE_ID);
}
@Test
@ -284,7 +298,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(Device.PRIMARY_ID, test));
}
@ -303,7 +317,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey));
verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any());
verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(Device.PRIMARY_ID, replacementKey));
}
@ -329,20 +343,20 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@Test
void validSingleRequestPqTestNoPqKeysV2() {
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
@ -353,15 +367,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -376,15 +390,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -398,14 +412,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -420,15 +434,15 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).takePQ(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -444,14 +458,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(1).getPreKey());
assertThat(result.getDevice(1).getPqPreKey()).isNull();
assertThat(result.getDevice(1).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_PNI, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, 1);
verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -481,14 +495,14 @@ class KeysControllerTest {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(1).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(1).getPqPreKey());
assertEquals(existsAccount.getDevice(1).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(1).getSignedPreKey());
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
@ -534,10 +548,14 @@ class KeysControllerTest {
@Test
void validMultiRequestTestV2() {
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY2)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@ -548,56 +566,62 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey();
long registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId();
byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId();
assertEquals(SAMPLE_KEY, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId();
assertEquals(SAMPLE_KEY2, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId();
assertEquals(SAMPLE_KEY4, preKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS);
}
@Test
void validMultiRequestPqTestV2() {
when(KEYS.takeEC(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takePQ(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.takeEC(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, 4)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, 1)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, 2)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, 3)).thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY3)));
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
@ -609,51 +633,51 @@ class KeysControllerTest {
assertThat(results.getDevicesCount()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
ECSignedPreKey signedPreKey = results.getDevice(1).getSignedPreKey();
ECPreKey preKey = results.getDevice(1).getPreKey();
KEMSignedPreKey pqPreKey = results.getDevice(1).getPqPreKey();
long registrationId = results.getDevice(1).getRegistrationId();
long deviceId = results.getDevice(1).getDeviceId();
ECSignedPreKey signedPreKey = results.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey();
ECPreKey preKey = results.getDevice(SAMPLE_DEVICE_ID).getPreKey();
KEMSignedPreKey pqPreKey = results.getDevice(SAMPLE_DEVICE_ID).getPqPreKey();
int registrationId = results.getDevice(SAMPLE_DEVICE_ID).getRegistrationId();
byte deviceId = results.getDevice(SAMPLE_DEVICE_ID).getDeviceId();
assertEquals(SAMPLE_KEY, preKey);
assertEquals(SAMPLE_PQ_KEY, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, signedPreKey);
assertThat(deviceId).isEqualTo(1);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID);
signedPreKey = results.getDevice(2).getSignedPreKey();
preKey = results.getDevice(2).getPreKey();
pqPreKey = results.getDevice(2).getPqPreKey();
registrationId = results.getDevice(2).getRegistrationId();
deviceId = results.getDevice(2).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID2).getPreKey();
pqPreKey = results.getDevice(SAMPLE_DEVICE_ID2).getPqPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID2).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID2).getDeviceId();
assertThat(preKey).isNull();
assertEquals(SAMPLE_PQ_KEY2, pqPreKey);
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertEquals(SAMPLE_SIGNED_KEY2, signedPreKey);
assertThat(deviceId).isEqualTo(2);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID2);
signedPreKey = results.getDevice(4).getSignedPreKey();
preKey = results.getDevice(4).getPreKey();
pqPreKey = results.getDevice(4).getPqPreKey();
registrationId = results.getDevice(4).getRegistrationId();
deviceId = results.getDevice(4).getDeviceId();
signedPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getSignedPreKey();
preKey = results.getDevice(SAMPLE_DEVICE_ID4).getPreKey();
pqPreKey = results.getDevice(SAMPLE_DEVICE_ID4).getPqPreKey();
registrationId = results.getDevice(SAMPLE_DEVICE_ID4).getRegistrationId();
deviceId = results.getDevice(SAMPLE_DEVICE_ID4).getDeviceId();
assertEquals(SAMPLE_KEY4, preKey);
assertThat(pqPreKey).isNull();
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
assertThat(deviceId).isEqualTo(SAMPLE_DEVICE_ID4);
verify(KEYS).takeEC(EXISTS_UUID, 1);
verify(KEYS).takePQ(EXISTS_UUID, 1);
verify(KEYS).takeEC(EXISTS_UUID, 2);
verify(KEYS).takePQ(EXISTS_UUID, 2);
verify(KEYS).takeEC(EXISTS_UUID, 4);
verify(KEYS).takePQ(EXISTS_UUID, 4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 1);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, 4);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verifyNoMoreInteractions(KEYS);
}
@ -719,7 +743,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@ -750,7 +775,8 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
@ -852,7 +878,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), eq(signedPreKey),
isNull());
assertThat(listCaptor.getValue()).containsExactly(preKey);
@ -884,7 +911,8 @@ class KeysControllerTest {
ArgumentCaptor<List<ECPreKey>> ecCaptor = ArgumentCaptor.forClass(List.class);
ArgumentCaptor<List<KEMSignedPreKey>> pqCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(1L), ecCaptor.capture(), pqCaptor.capture(), eq(signedPreKey), eq(pqLastResortPreKey));
verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(),
eq(signedPreKey), eq(pqLastResortPreKey));
assertThat(ecCaptor.getValue()).containsExactly(preKey);
assertThat(pqCaptor.getValue()).containsExactly(pqPreKey);
@ -928,7 +956,8 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List<ECPreKey>> listCaptor = ArgumentCaptor.forClass(List.class);
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture(), isNull(), eq(signedPreKey), isNull());
verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(),
eq(signedPreKey), isNull());
List<ECPreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@ -953,7 +982,8 @@ class KeysControllerTest {
resources.getJerseyTest()
.target("/v2/keys")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, 2L, AuthHelper.VALID_PASSWORD_3_LINKED))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_3, SAMPLE_DEVICE_ID2,
AuthHelper.VALID_PASSWORD_3_LINKED))
.put(Entity.entity(preKeyState, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);

View File

@ -135,15 +135,15 @@ class MessageControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final UUID SINGLE_DEVICE_UUID = UUID.fromString("11111111-1111-1111-1111-111111111111");
private static final UUID SINGLE_DEVICE_PNI = UUID.fromString("11111111-0000-0000-0000-111111111111");
private static final int SINGLE_DEVICE_ID1 = 1;
private static final byte SINGLE_DEVICE_ID1 = 1;
private static final int SINGLE_DEVICE_REG_ID1 = 111;
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private static final UUID MULTI_DEVICE_UUID = UUID.fromString("22222222-2222-2222-2222-222222222222");
private static final UUID MULTI_DEVICE_PNI = UUID.fromString("22222222-0000-0000-0000-222222222222");
private static final int MULTI_DEVICE_ID1 = 1;
private static final int MULTI_DEVICE_ID2 = 2;
private static final int MULTI_DEVICE_ID3 = 3;
private static final byte MULTI_DEVICE_ID1 = 1;
private static final byte MULTI_DEVICE_ID2 = 2;
private static final byte MULTI_DEVICE_ID3 = 3;
private static final int MULTI_DEVICE_REG_ID1 = 222;
private static final int MULTI_DEVICE_REG_ID2 = 333;
private static final int MULTI_DEVICE_REG_ID3 = 444;
@ -225,7 +225,8 @@ class MessageControllerTest {
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);
}
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
final ECSignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);
@ -526,13 +527,14 @@ class MessageControllerTest {
final UUID updatedPniOne = UUID.randomUUID();
List<Envelope> envelopes = List.of(
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, 2,
generateEnvelope(messageGuidOne, Envelope.Type.CIPHERTEXT_VALUE, timestampOne, sourceUuid, (byte) 2,
AuthHelper.VALID_UUID, updatedPniOne, "hi there".getBytes(), 0, false),
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid, 2,
generateEnvelope(messageGuidTwo, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo, sourceUuid,
(byte) 2,
AuthHelper.VALID_UUID, null, null, 0, true)
);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean()))
.thenReturn(Mono.just(new Pair<>(envelopes, false)));
final String userAgent = "Test-UA";
@ -580,13 +582,13 @@ class MessageControllerTest {
final long timestampTwo = 313388;
final List<Envelope> messages = List.of(
generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), 2,
generateEnvelope(UUID.randomUUID(), Envelope.Type.CIPHERTEXT_VALUE, timestampOne, UUID.randomUUID(), (byte) 2,
AuthHelper.VALID_UUID, null, "hi there".getBytes(), 0),
generateEnvelope(UUID.randomUUID(), Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE, timestampTwo,
UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)
UUID.randomUUID(), (byte) 2, AuthHelper.VALID_UUID, null, null, 0)
);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq((byte) 1), anyBoolean()))
.thenReturn(Mono.just(new Pair<>(messages, false)));
Response response =
@ -606,24 +608,24 @@ class MessageControllerTest {
UUID sourceUuid = UUID.randomUUID();
UUID uuid1 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid1, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid1, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(uuid1, Envelope.Type.CIPHERTEXT_VALUE,
timestamp, sourceUuid, 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
timestamp, sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, "hi".getBytes(), 0))));
UUID uuid2 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid2, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid2, null))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(generateEnvelope(
uuid2, Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE,
System.currentTimeMillis(), sourceUuid, 1, AuthHelper.VALID_UUID, null, null, 0))));
System.currentTimeMillis(), sourceUuid, (byte) 1, AuthHelper.VALID_UUID, null, null, 0))));
UUID uuid3 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid3, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid3, null))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
UUID uuid4 = UUID.randomUUID();
when(messagesManager.delete(AuthHelper.VALID_UUID, 1, uuid4, null))
when(messagesManager.delete(AuthHelper.VALID_UUID, (byte) 1, uuid4, null))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("Oh No")));
Response response = resources.getJerseyTest()
@ -633,7 +635,7 @@ class MessageControllerTest {
.delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq(1L),
verify(receiptSender).sendReceipt(eq(new AciServiceIdentifier(AuthHelper.VALID_UUID)), eq((byte) 1),
eq(new AciServiceIdentifier(sourceUuid)), eq(timestamp));
response = resources.getJerseyTest()
@ -879,7 +881,7 @@ class MessageControllerTest {
.request()
.header(OptionalAccess.UNIDENTIFIED, Base64.getEncoder().encodeToString(UNIDENTIFIED_ACCESS_BYTES))
.put(Entity.entity(new IncomingMessageList(
List.of(new IncomingMessage(1, 1L, 1, new String(contentBytes))), false, true,
List.of(new IncomingMessage(1, (byte) 1, 1, new String(contentBytes))), false, true,
System.currentTimeMillis()),
MediaType.APPLICATION_JSON_TYPE));
@ -919,7 +921,7 @@ class MessageControllerTest {
);
}
private static void writePayloadDeviceId(ByteBuffer bb, long deviceId) {
private static void writePayloadDeviceId(ByteBuffer bb, byte deviceId) {
long x = deviceId;
// write the device-id in the 7-bit varint format we use, least significant bytes first.
do {
@ -1155,7 +1157,7 @@ class MessageControllerTest {
if (known) {
r1 = new Recipient(new AciServiceIdentifier(SINGLE_DEVICE_UUID), SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, new byte[48]);
} else {
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), 999, 999, new byte[48]);
r1 = new Recipient(new AciServiceIdentifier(UUID.randomUUID()), (byte) 99, 999, new byte[48]);
}
Recipient r2 = new Recipient(new AciServiceIdentifier(MULTI_DEVICE_UUID), MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, new byte[48]);
@ -1250,7 +1252,7 @@ class MessageControllerTest {
SystemMapper.jsonMapper().getTypeFactory().constructCollectionType(List.class, AccountMismatchedDevices.class));
assertEquals(List.of(new AccountMismatchedDevices(serviceIdentifier,
new MismatchedDevices(Collections.emptyList(), List.of((long) MULTI_DEVICE_ID3)))),
new MismatchedDevices(Collections.emptyList(), List.of(MULTI_DEVICE_ID3)))),
mismatchedDevices);
}
@ -1298,7 +1300,8 @@ class MessageControllerTest {
assertEquals(1, staleDevices.size());
assertEquals(serviceIdentifier, staleDevices.get(0).uuid());
assertEquals(Set.of((long) MULTI_DEVICE_ID1, (long) MULTI_DEVICE_ID2), new HashSet<>(staleDevices.get(0).devices().staleDevices()));
assertEquals(Set.of(MULTI_DEVICE_ID1, MULTI_DEVICE_ID2),
new HashSet<>(staleDevices.get(0).devices().staleDevices()));
}
private static Stream<Arguments> sendMultiRecipientMessageStaleDevices() {
@ -1380,12 +1383,12 @@ class MessageControllerTest {
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp) {
return generateEnvelope(guid, type, timestamp, sourceUuid, sourceDevice, destinationUuid, updatedPni, content, serverTimestamp, false);
}
private static Envelope generateEnvelope(UUID guid, int type, long timestamp, UUID sourceUuid,
int sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
byte sourceDevice, UUID destinationUuid, UUID updatedPni, byte[] content, long serverTimestamp, boolean story) {
final MessageProtos.Envelope.Builder builder = MessageProtos.Envelope.newBuilder()
.setType(MessageProtos.Envelope.Type.forNumber(type))
@ -1413,14 +1416,14 @@ class MessageControllerTest {
private static Recipient genRecipient(Random rng) {
UUID u1 = UUID.randomUUID(); // non-null
long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903
byte d1 = (byte) (rng.nextInt(127) + 1); // 1 to 127
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
byte[] perKeyBytes = new byte[48]; // size=48, non-null
rng.nextBytes(perKeyBytes);
return new Recipient(new AciServiceIdentifier(u1), d1, dr1, perKeyBytes);
}
private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
private static void roundTripVarint(byte expected, byte[] bytes) throws Exception {
ByteBuffer bb = ByteBuffer.wrap(bytes);
writePayloadDeviceId(bb, expected);
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
@ -1434,15 +1437,17 @@ class MessageControllerTest {
byte[] bytes = new byte[12];
// some static test cases
for (long i = 1L; i <= 10L; i++) {
for (byte i = 1; i <= 10; i++) {
roundTripVarint(i, bytes);
}
roundTripVarint(Long.MAX_VALUE, bytes);
roundTripVarint(Byte.MAX_VALUE, bytes);
for (int i = 0; i < 1000; i++) {
// we need to ensure positive device IDs
long start = rng.nextLong() & Long.MAX_VALUE;
if (start == 0L) start = 1L;
byte start = (byte) rng.nextInt(128);
if (start == 0L) {
start = 1;
}
// run the test for this case
roundTripVarint(start, bytes);

View File

@ -75,12 +75,12 @@ class OutgoingMessageEntityTest {
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, 4444L, 55, "AAAAAA");
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, "AAAAAA");
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),
account,
123L,
(byte) 123,
System.currentTimeMillis(),
false,
true,

View File

@ -170,7 +170,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void deleteAccountLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
@ -215,7 +215,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void setRegistrationLockLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,
@ -240,7 +240,7 @@ class AccountsGrpcServiceTest extends SimpleBaseGrpcTest<AccountsGrpcService, Ac
@Test
void clearRegistrationLockLinkedDevice() {
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
getMockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
//noinspection ResultOfMethodCallIgnored
GrpcTestUtils.assertStatusException(Status.PERMISSION_DENIED,

View File

@ -7,7 +7,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
@ -88,7 +88,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
return CompletableFuture.completedFuture(account);
});
when(accountsManager.updateDeviceAsync(any(), anyLong(), any()))
when(accountsManager.updateDeviceAsync(any(), anyByte(), any()))
.thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final Device device = account.getDevice(invocation.getArgument(1)).orElseThrow();
@ -99,8 +99,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
return CompletableFuture.completedFuture(account);
});
when(keysManager.delete(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
return new DevicesGrpcService(accountsManager, keysManager, messagesManager);
}
@ -120,7 +120,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final String linkedDeviceName = "A linked device";
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn(Device.PRIMARY_ID + 1);
when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(linkedDevice.getCreated()).thenReturn(linkedDeviceCreated.toEpochMilli());
when(linkedDevice.getLastSeen()).thenReturn(linkedDeviceLastSeen.toEpochMilli());
when(linkedDevice.getName())
@ -147,7 +147,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@Test
void removeDevice() {
final long deviceId = 17;
final byte deviceId = 17;
final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId)
@ -167,15 +167,15 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@Test
void removeDeviceNonPrimaryAuthenticated() {
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, Device.PRIMARY_ID + 1);
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, (byte) (Device.PRIMARY_ID + 1));
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(17)
.build()));
}
@ParameterizedTest
@ValueSource(longs = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1})
void setDeviceName(final long deviceId) {
@ValueSource(bytes = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1})
void setDeviceName(final byte deviceId) {
mockAuthenticationInterceptor().setAuthenticatedDevice(AUTHENTICATED_ACI, deviceId);
final Device device = mock(Device.class);
@ -212,7 +212,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@ParameterizedTest
@MethodSource
void setPushToken(final long deviceId,
void setPushToken(final byte deviceId,
final SetPushTokenRequest request,
@Nullable final String expectedApnsToken,
@Nullable final String expectedApnsVoipToken,
@ -238,7 +238,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final Stream.Builder<Arguments> streamBuilder = Stream.builder();
for (final long deviceId : new long[] { Device.PRIMARY_ID, Device.PRIMARY_ID + 1 }) {
for (final byte deviceId : new byte[]{Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) {
streamBuilder.add(Arguments.of(deviceId,
SetPushTokenRequest.newBuilder()
.setApnsTokenRequest(SetPushTokenRequest.ApnsTokenRequest.newBuilder()
@ -284,7 +284,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final SetPushTokenResponse ignored = authenticatedServiceStub().setPushToken(request);
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
}
private static Stream<Arguments> setPushTokenUnchanged() {
@ -323,7 +323,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
final Device device = mock(Device.class);
when(authenticatedAccount.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().setPushToken(request));
verify(accountsManager, never()).updateDevice(any(), anyLong(), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
}
private static Stream<Arguments> setPushTokenIllegalArgument() {
@ -342,7 +342,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
@ParameterizedTest
@MethodSource
void clearPushToken(final long deviceId,
void clearPushToken(final byte deviceId,
@Nullable final String apnsToken,
@Nullable final String apnsVoipToken,
@Nullable final String fcmToken,
@ -379,17 +379,17 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
Arguments.of(Device.PRIMARY_ID, null, "apns-voip-token", null, "OWI"),
Arguments.of(Device.PRIMARY_ID, null, null, "fcm-token", "OWA"),
Arguments.of(Device.PRIMARY_ID, null, null, null, null),
Arguments.of(Device.PRIMARY_ID + 1, "apns-token", null, null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, "apns-token", "apns-voip-token", null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, null, "apns-voip-token", null, "OWP"),
Arguments.of(Device.PRIMARY_ID + 1, null, null, "fcm-token", "OWA"),
Arguments.of(Device.PRIMARY_ID + 1, null, null, null, null)
Arguments.of((byte) (Device.PRIMARY_ID + 1), "apns-token", null, null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), "apns-token", "apns-voip-token", null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, "apns-voip-token", null, "OWP"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, null, "fcm-token", "OWA"),
Arguments.of((byte) (Device.PRIMARY_ID + 1), null, null, null, null)
);
}
@CartesianTest
void setCapabilities(
@CartesianTest.Values(longs = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) final long deviceId,
@CartesianTest.Values(bytes = {Device.PRIMARY_ID, Device.PRIMARY_ID + 1}) final byte deviceId,
@CartesianTest.Values(booleans = {true, false}) final boolean storage,
@CartesianTest.Values(booleans = {true, false}) final boolean transfer,
@CartesianTest.Values(booleans = {true, false}) final boolean pni,

View File

@ -31,7 +31,7 @@ public final class GrpcTestUtils {
final MockAuthenticationInterceptor mockAuthenticationInterceptor,
final MockRemoteAddressInterceptor mockRemoteAddressInterceptor,
final UUID authenticatedAci,
final long authenticatedDeviceId,
final byte authenticatedDeviceId,
final BindableService service) {
mockAuthenticationInterceptor.setAuthenticatedDevice(authenticatedAci, authenticatedDeviceId);
extension.getServiceRegistry()

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.grpc.GrpcTestUtils.assertStatusException;
@ -184,8 +184,8 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
}
@ParameterizedTest
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final long deviceId) {
@ValueSource(bytes = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final byte deviceId) {
final UUID accountIdentifier = UUID.randomUUID();
final byte[] unidentifiedAccessKey = new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH];
@ -195,7 +195,7 @@ class KeysAnonymousGrpcServiceTest extends SimpleBaseGrpcTest<KeysAnonymousGrpcS
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(targetAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of(unidentifiedAccessKey));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.grpc;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
@ -151,7 +151,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
preKeys.add(new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()));
}
when(keysManager.storeEcOneTimePreKeys(any(), anyLong(), any()))
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
@ -222,7 +222,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
preKeys.add(KeysHelper.signedKEMPreKey(keyId, identityKeyPair));
}
when(keysManager.storeKemOneTimePreKeys(any(), anyLong(), any()))
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
//noinspection ResultOfMethodCallIgnored
@ -294,9 +294,9 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
@ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void setSignedPreKey(final org.signal.chat.common.IdentityType identityType) {
when(accountsManager.updateDeviceAsync(any(), anyLong(), any())).thenAnswer(invocation -> {
when(accountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(invocation -> {
final Account account = invocation.getArgument(0);
final long deviceId = invocation.getArgument(1);
final byte deviceId = invocation.getArgument(1);
final Consumer<Device> deviceUpdater = invocation.getArgument(2);
account.getDevice(deviceId).ifPresent(deviceUpdater);
@ -477,13 +477,16 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
when(accountsManager.getByServiceIdentifierAsync(argThat(serviceIdentifier -> serviceIdentifier.uuid().equals(identifier))))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
final Map<Long, ECPreKey> ecOneTimePreKeys = new HashMap<>();
final Map<Long, KEMSignedPreKey> kemPreKeys = new HashMap<>();
final Map<Long, ECSignedPreKey> ecSignedPreKeys = new HashMap<>();
final Map<Byte, ECPreKey> ecOneTimePreKeys = new HashMap<>();
final Map<Byte, KEMSignedPreKey> kemPreKeys = new HashMap<>();
final Map<Byte, ECSignedPreKey> ecSignedPreKeys = new HashMap<>();
final Map<Long, Device> devices = new HashMap<>();
final Map<Byte, Device> devices = new HashMap<>();
for (final long deviceId : List.of(1, 2)) {
final byte deviceId1 = 1;
final byte deviceId2 = 2;
for (final byte deviceId : List.of(deviceId1, deviceId2)) {
ecOneTimePreKeys.put(deviceId, new ECPreKey(1, Curve.generateKeyPair().getPublicKey()));
kemPreKeys.put(deviceId, KeysHelper.signedKEMPreKey(2, identityKeyPair));
ecSignedPreKeys.put(deviceId, KeysHelper.signedECPreKey(3, identityKeyPair));
@ -518,18 +521,18 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).signature()))
.build())
.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
.setKeyId(ecOneTimePreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(deviceId1).serializedPublicKey()))
.build())
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
.setKeyId(kemPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(deviceId1).signature()))
.build())
.build())
.build();
@ -537,8 +540,8 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
assertEquals(expectedResponse, response);
}
when(keysManager.takeEC(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takePQ(identifier, 2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takeEC(identifier, deviceId2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(keysManager.takePQ(identifier, deviceId2)).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
{
final GetPreKeysResponse response = authenticatedServiceStub().getPreKeys(GetPreKeysRequest.newBuilder()
@ -552,25 +555,25 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
.setIdentityKey(ByteString.copyFrom(identityKey.serialize()))
.putPreKeys(1, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(1L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId1).signature()))
.build())
.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecOneTimePreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(1L).serializedPublicKey()))
.setKeyId(ecOneTimePreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(ecOneTimePreKeys.get(deviceId1).serializedPublicKey()))
.build())
.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemPreKeys.get(1L).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(1L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(1L).signature()))
.setKeyId(kemPreKeys.get(deviceId1).keyId())
.setPublicKey(ByteString.copyFrom(kemPreKeys.get(deviceId1).serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemPreKeys.get(deviceId1).signature()))
.build())
.build())
.putPreKeys(2, GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKeys.get(2L).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(2L).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(2L).signature()))
.setKeyId(ecSignedPreKeys.get(deviceId2).keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKeys.get(deviceId2).serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKeys.get(deviceId2).signature()))
.build())
.build())
.build();
@ -593,15 +596,15 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
}
@ParameterizedTest
@ValueSource(longs = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final long deviceId) {
@ValueSource(bytes = {KeysGrpcHelper.ALL_DEVICES, 1})
void getPreKeysDeviceNotFound(final byte deviceId) {
final UUID accountIdentifier = UUID.randomUUID();
final Account targetAccount = mock(Account.class);
when(targetAccount.getUuid()).thenReturn(accountIdentifier);
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(accountIdentifier)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));
@ -621,7 +624,7 @@ class KeysGrpcServiceTest extends SimpleBaseGrpcTest<KeysGrpcService, KeysGrpc.K
when(targetAccount.getUuid()).thenReturn(UUID.randomUUID());
when(targetAccount.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(Curve.generateKeyPair().getPublicKey()));
when(targetAccount.getDevices()).thenReturn(Collections.emptyList());
when(targetAccount.getDevice(anyLong())).thenReturn(Optional.empty());
when(targetAccount.getDevice(anyByte())).thenReturn(Optional.empty());
when(accountsManager.getByServiceIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(targetAccount)));

View File

@ -55,7 +55,7 @@ public abstract class SimpleBaseGrpcTest<SERVICE extends BindableService, STUB e
protected static final UUID AUTHENTICATED_ACI = UUID.randomUUID();
protected static final long AUTHENTICATED_DEVICE_ID = Device.PRIMARY_ID;
protected static final byte AUTHENTICATED_DEVICE_ID = Device.PRIMARY_ID;
private AutoCloseable mocksCloseable;

View File

@ -54,7 +54,7 @@ class APNSenderTest {
apnsClient = mock(ApnsClient.class);
apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, BUNDLE_ID);
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(destinationAccount.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(destinationDevice));
when(destinationDevice.getApnId()).thenReturn(DESTINATION_DEVICE_TOKEN);
}

View File

@ -30,7 +30,6 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
@ -54,7 +53,7 @@ class ApnPushNotificationSchedulerTest {
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32);
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@ -98,12 +97,12 @@ class ApnPushNotificationSchedulerTest {
final List<String> pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Optional<Pair<String, Long>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
final Optional<Pair<String, Byte>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
pendingDestinations.get(0));
assertTrue(maybeUuidAndDeviceId.isPresent());
assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first());
assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second());
assertEquals(DEVICE_ID, maybeUuidAndDeviceId.get().second());
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
@ -236,8 +235,6 @@ class ApnPushNotificationSchedulerTest {
final AccountsManager accountsManager = mock(AccountsManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
apnPushNotificationScheduler = new ApnPushNotificationScheduler(redisCluster, apnSender,
accountsManager, dedicatedThreadCount);

View File

@ -76,7 +76,7 @@ class ClientPresenceManagerTest {
@Test
void testIsPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
@ -87,7 +87,7 @@ class ClientPresenceManagerTest {
@Test
void testIsLocallyPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
@ -100,7 +100,7 @@ class ClientPresenceManagerTest {
@Test
void testLocalDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final AtomicInteger displacementCounter = new AtomicInteger(0);
final DisplacedPresenceListener displacementListener = connectedElsewhere -> displacementCounter.incrementAndGet();
@ -117,7 +117,7 @@ class ClientPresenceManagerTest {
@Test
void testRemoteDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
@ -135,7 +135,7 @@ class ClientPresenceManagerTest {
@Test
void testRemoteDisplacementAfterTopologyChange() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
@ -157,7 +157,7 @@ class ClientPresenceManagerTest {
@Test
void testClearPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
@ -210,7 +210,7 @@ class ClientPresenceManagerTest {
@Test
void testInitialPresenceExpiration() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
@ -225,7 +225,7 @@ class ClientPresenceManagerTest {
@Test
void testRenewPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final String presenceKey = ClientPresenceManager.getPresenceKey(accountUuid, deviceId);
@ -252,7 +252,7 @@ class ClientPresenceManagerTest {
@Test
void testExpiredPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
@ -266,7 +266,7 @@ class ClientPresenceManagerTest {
}
private void addClientPresence(final String managerId) {
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7);
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), (byte) 7);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().set(clientPresenceKey, managerId);
@ -278,17 +278,17 @@ class ClientPresenceManagerTest {
void testClearAllOnStop() {
final int localAccounts = 10;
final UUID[] localUuids = new UUID[localAccounts];
final long[] localDeviceIds = new long[localAccounts];
final byte[] localDeviceIds = new byte[localAccounts];
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
localDeviceIds[i] = (byte) i;
clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP);
}
final UUID displacedAccountUuid = UUID.randomUUID();
final long displacedAccountDeviceId = 7;
final byte displacedAccountDeviceId = 7;
clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> connection.sync()
@ -299,7 +299,7 @@ class ClientPresenceManagerTest {
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
localDeviceIds[i] = (byte) i;
assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i]));
}
@ -346,7 +346,7 @@ class ClientPresenceManagerTest {
@Test
void testSetPresentRemotely() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);
@ -360,7 +360,7 @@ class ClientPresenceManagerTest {
@Test
void testDisconnectPresenceLocally() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);
@ -374,7 +374,7 @@ class ClientPresenceManagerTest {
@Test
void testDisconnectPresenceRemotely() {
final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1L;
final byte deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = connectedElsewhere -> displaced.complete(null);

View File

@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
@ -42,7 +42,7 @@ class MessageSenderTest {
private MessageSender messageSender;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
@BeforeEach
void setUp() {
@ -73,7 +73,7 @@ class MessageSenderTest {
ArgumentCaptor<MessageProtos.Envelope> envelopeArgumentCaptor = ArgumentCaptor.forClass(
MessageProtos.Envelope.class);
verify(messagesManager).insert(any(), anyLong(), envelopeArgumentCaptor.capture());
verify(messagesManager).insert(any(), anyByte(), envelopeArgumentCaptor.capture());
assertTrue(envelopeArgumentCaptor.getValue().getEphemeral());
@ -87,7 +87,7 @@ class MessageSenderTest {
messageSender.sendMessage(account, device, message, true);
verify(messagesManager, never()).insert(any(), anyLong(), any());
verify(messagesManager, never()).insert(any(), anyByte(), any());
verifyNoInteractions(pushNotificationManager);
}

View File

@ -1,6 +1,16 @@
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import com.google.protobuf.ByteString;
import java.time.Duration;
import java.util.Random;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@ -11,17 +21,6 @@ import org.whispersystems.textsecuregcm.redis.RedisSingletonExtension;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import java.time.Duration;
import java.util.Random;
import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.after;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
class ProvisioningManagerTest {
private ProvisioningManager provisioningManager;
@ -44,7 +43,7 @@ class ProvisioningManagerTest {
@Test
void sendProvisioningMessage() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);
@ -65,7 +64,7 @@ class ProvisioningManagerTest {
@Test
void removeListener() {
final ProvisioningAddress address = new ProvisioningAddress("address", 0);
final ProvisioningAddress address = new ProvisioningAddress("address", (byte) 0);
final byte[] content = new byte[16];
new Random().nextBytes(content);

View File

@ -35,7 +35,7 @@ class PushLatencyManagerTest {
@MethodSource
void testTakeRecord(final boolean isVoip, final boolean isUrgent) throws ExecutionException, InterruptedException {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final byte deviceId = 1;
final Instant pushTimestamp = Instant.now();

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -85,15 +86,16 @@ class AccountTest {
when(agingSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31));
when(agingSecondaryDevice.isEnabled()).thenReturn(false);
when(agingSecondaryDevice.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(agingSecondaryDevice.getId()).thenReturn(deviceId2);
when(recentSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(1));
when(recentSecondaryDevice.isEnabled()).thenReturn(true);
when(recentSecondaryDevice.getId()).thenReturn(2L);
when(recentSecondaryDevice.getId()).thenReturn(deviceId2);
when(oldSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldSecondaryDevice.isEnabled()).thenReturn(false);
when(oldSecondaryDevice.getId()).thenReturn(2L);
when(oldSecondaryDevice.getId()).thenReturn(deviceId2);
when(senderKeyCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
@ -143,17 +145,17 @@ class AccountTest {
new DeviceCapabilities(true, true, false, false));
when(pniIncapableExpiredDevice.isEnabled()).thenReturn(false);
when(storiesCapableDevice.getId()).thenReturn(1L);
when(storiesCapableDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(storiesCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesCapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(2L);
when(storiesCapableDevice.getId()).thenReturn(deviceId2);
when(storiesIncapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesIncapableDevice.isEnabled()).thenReturn(true);
when(storiesCapableDevice.getId()).thenReturn(3L);
when(storiesCapableDevice.getId()).thenReturn((byte) 3);
when(storiesIncapableExpiredDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false, false));
when(storiesIncapableExpiredDevice.isEnabled()).thenReturn(false);
@ -192,10 +194,11 @@ class AccountTest {
when(disabledPrimaryDevice.isEnabled()).thenReturn(false);
when(disabledLinkedDevice.isEnabled()).thenReturn(false);
when(enabledPrimaryDevice.getId()).thenReturn(1L);
when(enabledLinkedDevice.getId()).thenReturn(2L);
when(disabledPrimaryDevice.getId()).thenReturn(1L);
when(disabledLinkedDevice.getId()).thenReturn(2L);
when(enabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
final byte deviceId2 = 2;
when(enabledLinkedDevice.getId()).thenReturn(deviceId2);
when(disabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(disabledLinkedDevice.getId()).thenReturn(deviceId2);
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice)).isEnabled());
assertTrue(AccountsHelper.generateTestAccount("+14151234567", List.of(enabledPrimaryDevice, enabledLinkedDevice)).isEnabled());
@ -214,15 +217,15 @@ class AccountTest {
final DeviceCapabilities transferCapabilities = mock(DeviceCapabilities.class);
final DeviceCapabilities nonTransferCapabilities = mock(DeviceCapabilities.class);
when(transferCapablePrimaryDevice.getId()).thenReturn(1L);
when(transferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(transferCapablePrimaryDevice.isPrimary()).thenReturn(true);
when(transferCapablePrimaryDevice.getCapabilities()).thenReturn(transferCapabilities);
when(nonTransferCapablePrimaryDevice.getId()).thenReturn(1L);
when(nonTransferCapablePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(nonTransferCapablePrimaryDevice.isPrimary()).thenReturn(true);
when(nonTransferCapablePrimaryDevice.getCapabilities()).thenReturn(nonTransferCapabilities);
when(transferCapableLinkedDevice.getId()).thenReturn(2L);
when(transferCapableLinkedDevice.getId()).thenReturn((byte) 2);
when(transferCapableLinkedDevice.isPrimary()).thenReturn(false);
when(transferCapableLinkedDevice.getCapabilities()).thenReturn(transferCapabilities);
@ -311,21 +314,31 @@ class AccountTest {
final Account account = AccountsHelper.generateTestAccount("+14151234567", UUID.randomUUID(), UUID.randomUUID(), devices, new byte[0]);
assertThat(account.getNextDeviceId()).isEqualTo(2L);
final byte deviceId2 = 2;
assertThat(account.getNextDeviceId()).isEqualTo(deviceId2);
account.addDevice(createDevice(2L));
account.addDevice(createDevice(deviceId2));
assertThat(account.getNextDeviceId()).isEqualTo(3L);
final byte deviceId3 = 3;
assertThat(account.getNextDeviceId()).isEqualTo(deviceId3);
account.addDevice(createDevice(3L));
account.addDevice(createDevice(deviceId3));
setEnabled(account.getDevice(2L).orElseThrow(), false);
setEnabled(account.getDevice(deviceId2).orElseThrow(), false);
assertThat(account.getNextDeviceId()).isEqualTo(4L);
assertThat(account.getNextDeviceId()).isEqualTo((byte) 4);
account.removeDevice(2L);
account.removeDevice(deviceId2);
assertThat(account.getNextDeviceId()).isEqualTo(2L);
assertThat(account.getNextDeviceId()).isEqualTo(deviceId2);
while (account.getNextDeviceId() < Device.MAXIMUM_DEVICE_ID) {
account.addDevice(createDevice(account.getNextDeviceId()));
}
account.addDevice(createDevice(Device.MAXIMUM_DEVICE_ID));
assertThatThrownBy(account::getNextDeviceId).isInstanceOf(RuntimeException.class);
}
@Test
@ -399,7 +412,7 @@ class AccountTest {
final Device disabledPrimary = mock(Device.class);
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);
final long linked1DeviceId = Device.PRIMARY_ID + 1;
final byte linked1DeviceId = Device.PRIMARY_ID + 1;
final Device enabledLinked1 = mock(Device.class);
when(enabledLinked1.isEnabled()).thenReturn(true);
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);
@ -407,7 +420,7 @@ class AccountTest {
final Device disabledLinked1 = mock(Device.class);
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);
final long linked2DeviceId = Device.PRIMARY_ID + 2;
final byte linked2DeviceId = Device.PRIMARY_ID + 2;
final Device enabledLinked2 = mock(Device.class);
when(enabledLinked2.isEnabled()).thenReturn(true);
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);

View File

@ -178,8 +178,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Long, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
final Account updatedAccount = accountsManager.changeNumber(account, secondNumber, pniIdentityKey, preKeys, null, registrationIds);

View File

@ -141,8 +141,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()),
a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(1);
a.addDevice(DevicesHelper.createDevice(1));
a.removeDevice(Device.PRIMARY_ID);
a.addDevice(DevicesHelper.createDevice(Device.PRIMARY_ID));
});
uuid = account.getUuid();
@ -212,7 +212,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
}, mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final long deviceId, final Consumer<Device> deviceMutation) {
private CompletableFuture<?> modifyDevice(final UUID uuid, final byte deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.getByAccountIdentifier(uuid).orElseThrow();

View File

@ -876,7 +876,7 @@ class AccountsManagerTest {
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@ -909,7 +909,7 @@ class AccountsManagerTest {
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
final byte deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@ -1064,7 +1064,8 @@ class AccountsManagerTest {
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
assertThrows(IllegalArgumentException.class,
() -> accountsManager.changeNumber(
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()), Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of(1L, 101)),
account, number, new IdentityKey(Curve.generateKeyPair().getPublicKey()),
Map.of(Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, pniIdentityKeyPair)), null, Map.of((byte) 1, 101)),
"AccountsManager should not allow use of changeNumber with new PNI keys but without changing number");
verify(accounts, never()).update(any());
@ -1107,24 +1108,26 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2;
final byte deviceId3 = 3;
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L, 3L)));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID, deviceId3)));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final List<Device> devices = List.of(
DevicesHelper.createDevice(1L, 0L, 101),
DevicesHelper.createDevice(2L, 0L, 102),
DevicesHelper.createDisabledDevice(3L, 103));
DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102),
DevicesHelper.createDisabledDevice(deviceId3, 103));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final Account updatedAccount = accountsManager.changeNumber(
account, targetNumber, new IdentityKey(Curve.generateKeyPair().getPublicKey()), newSignedKeys, newSignedPqKeys, newRegistrationIds);
@ -1140,7 +1143,8 @@ class AccountsManagerTest {
verify(keysManager).delete(originalPni);
verify(keysManager).getPqEnabledDevices(uuid);
verify(keysManager).storeEcSignedPreKeys(newPni, newSignedKeys);
verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(newPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
verifyNoMoreInteractions(keysManager);
}
@ -1153,19 +1157,22 @@ class AccountsManagerTest {
final UUID uuid = UUID.randomUUID();
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final byte deviceId2 = 2;
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair));
final Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
final Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
when(keysManager.getPqEnabledDevices(uuid)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
final List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
assertThrows(MismatchedDevicesException.class,
() -> accountsManager.changeNumber(
@ -1189,18 +1196,20 @@ class AccountsManagerTest {
@Test
void testPniUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1217,7 +1226,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI stuff should
@ -1236,26 +1245,29 @@ class AccountsManagerTest {
@Test
void testPniPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(CompletableFuture.completedFuture(List.of(1L)));
when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(
CompletableFuture.completedFuture(List.of(Device.PRIMARY_ID)));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1270,7 +1282,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
@ -1287,23 +1299,26 @@ class AccountsManagerTest {
verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys);
// only the pq key for the already-pq-enabled device should be saved
verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L))));
verify(keysManager).storePqLastResort(eq(oldPni),
eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID))));
}
@Test
void testPniNonPqToPqUpdate() throws MismatchedDevicesException {
final String number = "+14152222222";
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair),
2L, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair),
deviceId2, KeysHelper.signedKEMPreKey(4, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
@ -1312,7 +1327,7 @@ class AccountsManagerTest {
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1327,7 +1342,7 @@ class AccountsManagerTest {
assertNull(updatedAccount.getIdentityKey(IdentityType.ACI));
assertEquals(oldSignedPreKeys, updatedAccount.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI))));
assertEquals(Map.of(1L, 101, 2L, 102),
assertEquals(Map.of(Device.PRIMARY_ID, 101, deviceId2, 102),
updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getRegistrationId)));
// PNI keys should
@ -1348,19 +1363,21 @@ class AccountsManagerTest {
@Test
void testPniUpdate_incompleteKeys() {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final byte deviceId2 = 2;
final byte deviceId3 = 3;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
2L, KeysHelper.signedECPreKey(1, identityKeyPair),
3L, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
deviceId2, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId3, KeysHelper.signedECPreKey(2, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
@ -1375,21 +1392,22 @@ class AccountsManagerTest {
@Test
void testPniPqUpdate_incompleteKeys() {
final String number = "+14152222222";
List<Device> devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102));
final byte deviceId2 = 2;
List<Device> devices = List.of(DevicesHelper.createDevice(Device.PRIMARY_ID, 0L, 101),
DevicesHelper.createDevice(deviceId2, 0L, 102));
Account account = AccountsHelper.generateTestAccount(number, UUID.randomUUID(), UUID.randomUUID(), devices, new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
final Map<Long, ECSignedPreKey> newSignedKeys = Map.of(
1L, KeysHelper.signedECPreKey(1, identityKeyPair),
2L, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Long, KEMSignedPreKey> newSignedPqKeys = Map.of(
1L, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Long, Integer> newRegistrationIds = Map.of(1L, 201, 2L, 202);
final Map<Byte, ECSignedPreKey> newSignedKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedECPreKey(1, identityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, identityKeyPair));
final Map<Byte, KEMSignedPreKey> newSignedPqKeys = Map.of(
Device.PRIMARY_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair));
Map<Byte, Integer> newRegistrationIds = Map.of(Device.PRIMARY_ID, 201, deviceId2, 202);
UUID oldUuid = account.getUuid();
UUID oldPni = account.getPhoneNumberIdentifier();
Map<Long, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
Map<Byte, ECSignedPreKey> oldSignedPreKeys = account.getDevices().stream()
.collect(Collectors.toMap(Device::getId, d -> d.getSignedPreKey(IdentityType.ACI)));
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());

View File

@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ -75,6 +76,9 @@ import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsTest {
private static final byte DEVICE_ID_1 = 1;
private static final byte DEVICE_ID_2 = 2;
private static final String BASE_64_URL_USERNAME_HASH_1 = "9p6Tip7BFefFOJzv4kv4GyXEYsBVfk_WbjNejdlOvQE";
private static final String BASE_64_URL_USERNAME_HASH_2 = "NLUom-CHwtemcdvOTTXdmXmzRIV7F05leS8lwkVK_vc";
private static final String BASE_64_URL_ENCRYPTED_USERNAME_1 = "md1votbj9r794DsqTNrBqA";
@ -156,7 +160,7 @@ class AccountsTest {
@Test
void testStore() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
@ -179,7 +183,7 @@ class AccountsTest {
void testStoreRecentlyDeleted() {
final UUID originalUuid = UUID.randomUUID();
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
@ -205,7 +209,7 @@ class AccountsTest {
@Test
void testStoreMulti() {
final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
accounts.create(account);
@ -218,13 +222,13 @@ class AccountsTest {
@Test
void testRetrieve() {
final List<Device> devicesFirst = List.of(generateDevice(1), generateDevice(2));
final List<Device> devicesFirst = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
UUID uuidFirst = UUID.randomUUID();
UUID pniFirst = UUID.randomUUID();
Account accountFirst = generateAccount("+14151112222", uuidFirst, pniFirst, devicesFirst);
final List<Device> devicesSecond = List.of(generateDevice(1), generateDevice(2));
final List<Device> devicesSecond = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
UUID uuidSecond = UUID.randomUUID();
UUID pniSecond = UUID.randomUUID();
@ -263,7 +267,7 @@ class AccountsTest {
@Test
void testRetrieveNoPni() throws JsonProcessingException {
final List<Device> devices = List.of(generateDevice(1), generateDevice(2));
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final UUID uuid = UUID.randomUUID();
final Account account = generateAccount("+14151112222", uuid, null, devices);
@ -321,7 +325,7 @@ class AccountsTest {
@Test
void testOverwrite() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
UUID firstUuid = UUID.randomUUID();
UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
@ -346,7 +350,7 @@ class AccountsTest {
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device));
final boolean freshUser = accounts.create(account);
@ -356,7 +360,7 @@ class AccountsTest {
assertPhoneNumberConstraintExists("+14151112222", firstUuid);
assertPhoneNumberIdentifierConstraintExists(firstPni, firstUuid);
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.create(invalidAccount));
@ -364,7 +368,7 @@ class AccountsTest {
@Test
void testUpdate() {
Device device = generateDevice (1 );
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@ -389,7 +393,7 @@ class AccountsTest {
assertThat(retrieved.isPresent()).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
device = generateDevice(1);
device = generateDevice(DEVICE_ID_1);
Account unknownAccount = generateAccount("+14151113333", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
@ -452,10 +456,10 @@ class AccountsTest {
@Test
void testDelete() {
final Device deletedDevice = generateDevice(1);
final Device deletedDevice = generateDevice(DEVICE_ID_1);
final Account deletedAccount = generateAccount("+14151112222", UUID.randomUUID(),
UUID.randomUUID(), List.of(deletedDevice));
final Device retainedDevice = generateDevice(1);
final Device retainedDevice = generateDevice(DEVICE_ID_1);
final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(),
UUID.randomUUID(), List.of(retainedDevice));
@ -485,7 +489,7 @@ class AccountsTest {
{
final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(),
UUID.randomUUID(), List.of(generateDevice(1)));
UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
final boolean freshUser = accounts.create(recreatedAccount);
@ -501,7 +505,7 @@ class AccountsTest {
@Test
void testMissing() {
Device device = generateDevice (1 );
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@ -518,7 +522,7 @@ class AccountsTest {
assertThat(accounts.getByAccountIdentifierAsync(UUID.randomUUID()).join()).isEmpty();
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@ -530,7 +534,7 @@ class AccountsTest {
assertThat(accounts.getByPhoneNumberIdentifierAsync(UUID.randomUUID()).join()).isEmpty();
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@ -544,7 +548,7 @@ class AccountsTest {
assertThat(accounts.getByE164Async(e164).join()).isEmpty();
final Account account =
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(1)));
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
@ -553,7 +557,7 @@ class AccountsTest {
@Test
void testCanonicallyDiscoverableSet() {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
account.setDiscoverableByPhoneNumber(false);
accounts.create(account);
@ -576,7 +580,7 @@ class AccountsTest {
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
@ -631,10 +635,10 @@ class AccountsTest {
final UUID originalPni = UUID.randomUUID();
final UUID targetPni = UUID.randomUUID();
final Device existingDevice = generateDevice(1);
final Device existingDevice = generateDevice(DEVICE_ID_1);
final Account existingAccount = generateAccount(targetNumber, UUID.randomUUID(), targetPni, List.of(existingDevice));
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
@ -653,7 +657,7 @@ class AccountsTest {
final String originalNumber = "+14151112222";
final String targetNumber = "+14151113333";
final Device device = generateDevice(1);
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
@ -969,7 +973,48 @@ class AccountsTest {
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isPresent();
}
private static Device generateDevice(long id) {
@Test
public void testInvalidDeviceIdDeserialization() throws Exception {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
final Device device2 = generateDevice((byte) 64);
account.addDevice(device2);
accounts.create(account);
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.build()).join();
final Map<?, ?> accountData = SystemMapper.jsonMapper()
.readValue(response.item().get(Accounts.ATTR_ACCOUNT_DATA).b().asByteArray(), Map.class);
final List<Map<Object, Object>> devices = (List<Map<Object, Object>>) accountData.get("devices");
assertEquals(Integer.valueOf(device2.getId()), devices.get(1).get("id"));
devices.get(1).put("id", Byte.MAX_VALUE + 5);
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().updateItem(UpdateItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression("SET #data = :data")
.expressionAttributeNames(Map.of("#data", Accounts.ATTR_ACCOUNT_DATA))
.expressionAttributeValues(
Map.of(":data", AttributeValues.fromByteArray(SystemMapper.jsonMapper().writeValueAsBytes(accountData))))
.build()).join();
final CompletionException e = assertThrows(CompletionException.class,
() -> accounts.getByAccountIdentifierAsync(account.getUuid()).join());
Throwable cause = e.getCause();
while (cause.getCause() != null) {
cause = cause.getCause();
}
assertInstanceOf(DeviceIdDeserializer.DeviceIdDeserializationException.class, cause);
}
private static Device generateDevice(byte id) {
return DevicesHelper.createDevice(id);
}
@ -979,7 +1024,7 @@ class AccountsTest {
}
private static Account generateAccount(String number, UUID uuid, final UUID pni) {
Device device = generateDevice(1);
Device device = generateDevice(DEVICE_ID_1);
return generateAccount(number, uuid, pni, List.of(device));
}

View File

@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -68,7 +69,7 @@ public class ChangeNumberManagerTest {
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(updatedPni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@ -87,7 +88,7 @@ public class ChangeNumberManagerTest {
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(pni);
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
@ -102,7 +103,7 @@ public class ChangeNumberManagerTest {
when(account.getNumber()).thenReturn("+18005551234");
changeNumberManager.changeNumber(account, "+18025551234", null, null, null, null, null);
verify(accountsManager).changeNumber(account, "+18025551234", null, null, null, null);
verify(accountsManager, never()).updateDevice(any(), eq(1L), any());
verify(accountsManager, never()).updateDevice(any(), anyByte(), any());
verify(messageSender, never()).sendMessage(eq(account), any(), any(), eq(false));
}
@ -112,7 +113,8 @@ public class ChangeNumberManagerTest {
when(account.getNumber()).thenReturn("+18005551234");
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
changeNumberManager.changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyList(), Collections.emptyMap());
verify(accountsManager).changeNumber(account, "+18025551234", pniIdentityKey, prekeys, null, Collections.emptyMap());
@ -133,18 +135,21 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@ -177,19 +182,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, changedE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -220,19 +229,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.changeNumber(account, originalE164, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -261,18 +274,21 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, null, List.of(msg), registrationIds);
@ -301,19 +317,23 @@ public class ChangeNumberManagerTest {
final Device d2 = mock(Device.class);
when(d2.isEnabled()).thenReturn(true);
when(d2.getId()).thenReturn(2L);
final byte deviceId2 = 2;
when(d2.getId()).thenReturn(deviceId2);
when(account.getDevice(2L)).thenReturn(Optional.of(d2));
when(account.getDevice(deviceId2)).thenReturn(Optional.of(d2));
when(account.getDevices()).thenReturn(List.of(d2));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final Map<Long, ECSignedPreKey> prekeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Long, KEMSignedPreKey> pqPrekeys = Map.of(3L, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair), 4L, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 19);
final Map<Byte, ECSignedPreKey> prekeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
deviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair));
final Map<Byte, KEMSignedPreKey> pqPrekeys = Map.of((byte) 3, KeysHelper.signedKEMPreKey(3, pniIdentityKeyPair),
(byte) 4, KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, deviceId2, 19);
final IncomingMessage msg = mock(IncomingMessage.class);
when(msg.destinationDeviceId()).thenReturn(2L);
when(msg.destinationDeviceId()).thenReturn(deviceId2);
when(msg.content()).thenReturn(Base64.getEncoder().encodeToString(new byte[]{1}));
changeNumberManager.updatePniKeys(account, pniIdentityKey, prekeys, pqPrekeys, List.of(msg), registrationIds);
@ -338,11 +358,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@ -350,15 +370,21 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
@ -371,11 +397,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@ -383,15 +409,21 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 1, "foo"),
new IncomingMessage(1, 3, 1, "foo"));
new IncomingMessage(1, destinationDeviceId2, 1, "foo"),
new IncomingMessage(1, destinationDeviceId3, 1, "foo"));
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECPublicKey pniIdentityKey = pniIdentityKeyPair.getPublicKey();
final Map<Long, ECSignedPreKey> preKeys = Map.of(1L, KeysHelper.signedECPreKey(1, pniIdentityKeyPair), 2L, KeysHelper.signedECPreKey(2, pniIdentityKeyPair), 3L, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID,
KeysHelper.signedECPreKey(1, pniIdentityKeyPair),
destinationDeviceId2, KeysHelper.signedECPreKey(2, pniIdentityKeyPair),
destinationDeviceId3, KeysHelper.signedECPreKey(3, pniIdentityKeyPair));
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, 17, destinationDeviceId2, 47,
destinationDeviceId3, 89);
assertThrows(StaleDevicesException.class,
() -> changeNumberManager.updatePniKeys(account, new IdentityKey(Curve.generateKeyPair().getPublicKey()), preKeys, null, messages, registrationIds));
@ -404,11 +436,11 @@ public class ChangeNumberManagerTest {
final List<Device> devices = new ArrayList<>();
for (int i = 1; i <= 3; i++) {
for (byte i = 1; i <= 3; i++) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((long) i);
when(device.getId()).thenReturn(i);
when(device.isEnabled()).thenReturn(true);
when(device.getRegistrationId()).thenReturn(i);
when(device.getRegistrationId()).thenReturn((int) i);
devices.add(device);
when(account.getDevice(i)).thenReturn(Optional.of(device));
@ -416,11 +448,13 @@ public class ChangeNumberManagerTest {
when(account.getDevices()).thenReturn(devices);
final byte destinationDeviceId2 = 2;
final byte destinationDeviceId3 = 3;
final List<IncomingMessage> messages = List.of(
new IncomingMessage(1, 2, 2, "foo"),
new IncomingMessage(1, 3, 3, "foo"));
new IncomingMessage(1, destinationDeviceId2, 2, "foo"),
new IncomingMessage(1, destinationDeviceId3, 3, "foo"));
final Map<Long, Integer> registrationIds = Map.of(1L, 17, 2L, 47, 3L, 89);
final Map<Byte, Integer> registrationIds = Map.of((byte) 1, 17, destinationDeviceId2, 47, destinationDeviceId3, 89);
assertThrows(IllegalArgumentException.class,
() -> changeNumberManager.changeNumber(account, "+18005559876", new IdentityKey(Curve.generateKeyPair().getPublicKey()), null, null, messages, registrationIds));

View File

@ -40,7 +40,7 @@ class KeysManagerTest {
Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final long DEVICE_ID = 1L;
private static final byte DEVICE_ID = 1;
private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair();
@ -169,7 +169,8 @@ class KeysManagerTest {
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
@ -180,10 +181,10 @@ class KeysManagerTest {
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
keysManager.delete(ACCOUNT_UUID).join();
@ -191,10 +192,10 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
}
@Test
@ -206,7 +207,8 @@ class KeysManagerTest {
generateTestKEMSignedPreKey(6))
.join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1,
final byte deviceId2 = DEVICE_ID + 1;
keysManager.store(ACCOUNT_UUID, deviceId2,
List.of(generateTestPreKey(7)),
List.of(generateTestKEMSignedPreKey(8)),
generateTestECSignedPreKey(9),
@ -217,10 +219,10 @@ class KeysManagerTest {
assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join();
@ -228,10 +230,10 @@ class KeysManagerTest {
assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join());
assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).join().isPresent());
assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join());
assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join());
assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent());
assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent());
}
@Test
@ -240,21 +242,29 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).join().isPresent());
final byte deviceId2 = 2;
final byte deviceId3 = 3;
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(1, identityKeyPair), (byte) 2,
KeysHelper.signedKEMPreKey(2, identityKeyPair))).join();
assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size());
assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId());
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId());
assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().isPresent());
keysManager.storePqLastResort(
ACCOUNT_UUID,
Map.of(DEVICE_ID, KeysHelper.signedKEMPreKey(3, identityKeyPair), deviceId3,
KeysHelper.signedKEMPreKey(4, identityKeyPair))).join();
assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).join().size(), "storing new last-resort keys should not create duplicates");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).join().get().keyId(), "storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).join().get().keyId(), "storing new last-resort keys should overwrite old ones");
assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().get().keyId(),
"storing new last-resort keys should leave untouched ones alone");
assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, deviceId3).join().get().keyId(),
"storing new last-resort keys should overwrite old ones");
}
@Test
@ -262,11 +272,14 @@ class KeysManagerTest {
final ECKeyPair identityKeyPair = Curve.generateKeyPair();
keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null, null).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), null, null, null,
KeysHelper.signedKEMPreKey(2, identityKeyPair)).join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), null,
List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair))
.join();
keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 3), null, null, null, null).join();
assertIterableEquals(
Set.of(DEVICE_ID + 1, DEVICE_ID + 2),
Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)),
Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join()));
}

View File

@ -124,17 +124,19 @@ class MessagePersisterIntegrationTest {
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, timestamp);
messagesCache.insert(messageGuid, account.getUuid(), 1, message);
messagesCache.insert(messageGuid, account.getUuid(), Device.PRIMARY_ID, message);
expectedMessages.add(message);
}
REDIS_CLUSTER_EXTENSION.getRedisCluster()
.useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY,
String.valueOf(SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), 1)) - 1)));
String.valueOf(
SlotHash.getSlot(MessagesCache.getMessageQueueKey(account.getUuid(), Device.PRIMARY_ID)) - 1)));
final AtomicBoolean messagesPersisted = new AtomicBoolean(false);
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
messagesManager.addMessageAvailabilityListener(account.getUuid(), Device.PRIMARY_ID,
new MessageAvailabilityListener() {
@Override
public boolean handleNewMessagesAvailable() {
return true;

View File

@ -9,8 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
@ -61,7 +61,7 @@ class MessagePersisterTest {
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
private static final byte DESTINATION_DEVICE_ID = 7;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
@ -90,9 +90,9 @@ class MessagePersisterTest {
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY, 1);
doAnswer(invocation -> {
when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> {
final UUID destinationUuid = invocation.getArgument(0);
final long destinationDeviceId = invocation.getArgument(1);
final byte destinationDeviceId = invocation.getArgument(1);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2);
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
@ -101,8 +101,8 @@ class MessagePersisterTest {
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())).get();
}
return null;
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any());
return messages.size();
});
}
@AfterEach
@ -153,7 +153,7 @@ class MessagePersisterTest {
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
verify(messagesDynamoDb, never()).store(any(), any(), anyByte());
}
@Test
@ -166,7 +166,7 @@ class MessagePersisterTest {
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
@ -183,7 +183,7 @@ class MessagePersisterTest {
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyByte());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@ -219,7 +219,7 @@ class MessagePersisterTest {
setNextSlotToPersist(SlotHash.getSlot(queueName));
// returning `0` indicates something not working correctly
when(messagesManager.persistMessages(any(UUID.class), anyLong(), anyList())).thenReturn(0);
when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenReturn(0);
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
assertThrows(MessagePersistenceException.class,
@ -228,22 +228,23 @@ class MessagePersisterTest {
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
while (true) {
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
for (byte deviceId = 1; deviceId < Device.MAXIMUM_DEVICE_ID; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
}
}
throw new IllegalStateException("Could not find a queue name for slot " + slot);
}
private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount,
private void insertMessages(final UUID accountUuid, final byte deviceId, final int messageCount,
final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();

View File

@ -85,7 +85,7 @@ class MessagesCacheTest {
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7;
private static final byte DESTINATION_DEVICE_ID = 7;
@BeforeEach
void setUp() throws Exception {
@ -311,7 +311,7 @@ class MessagesCacheTest {
void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
@ -323,7 +323,7 @@ class MessagesCacheTest {
messagesCache.clear(DESTINATION_UUID, DESTINATION_DEVICE_ID).join();
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(messageCount, get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
assertEquals(messageCount, get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount).size());
}
@ParameterizedTest
@ -331,7 +331,7 @@ class MessagesCacheTest {
void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (final byte deviceId : new byte[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
@ -343,7 +343,7 @@ class MessagesCacheTest {
messagesCache.clear(DESTINATION_UUID).join();
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount));
assertEquals(Collections.emptyList(), get(DESTINATION_UUID, (byte) (DESTINATION_DEVICE_ID + 1), messageCount));
}
@Test
@ -531,7 +531,7 @@ class MessagesCacheTest {
});
}
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final long destinationDeviceId,
private List<MessageProtos.Envelope> get(final UUID destinationUuid, final byte destinationDeviceId,
final int messageCount) {
return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
.take(messageCount, true)
@ -605,7 +605,7 @@ class MessagesCacheTest {
.thenReturn(Flux.from(emptyFinalPagePublisher))
.thenReturn(Flux.empty());
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), 1L);
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID);
// Why initialValue = 3?
// 1. messagesCache.getAllMessages() above produces the first call
@ -691,7 +691,7 @@ class MessagesCacheTest {
when(asyncCommands.evalsha(any(), any(), any(), any()))
.thenReturn((RedisFuture) removeSuccess);
final Publisher<?> allMessages = messagesCache.get(UUID.randomUUID(), 1L);
final Publisher<?> allMessages = messagesCache.get(UUID.randomUUID(), Device.PRIMARY_ID);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(5));
@ -752,7 +752,7 @@ class MessagesCacheTest {
.setDestinationUuid(UUID.randomUUID().toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(256))
envelopeBuilder.setSourceDevice(random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1)
.setSourceUuid(UUID.randomUUID().toString());
}

View File

@ -98,7 +98,7 @@ class MessagesDynamoDbTest {
@Test
void testSimpleFetchAfterInsert() {
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
final List<MessageProtos.Envelope> messagesStored = load(destinationUuid, destinationDeviceId,
@ -116,11 +116,12 @@ class MessagesDynamoDbTest {
@ValueSource(ints = {10, 100, 100, 1_000, 3_000})
void testLoadManyAfterInsert(final int messageCount) {
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000,
"message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
@ -148,18 +149,20 @@ class MessagesDynamoDbTest {
void testLimitedLoad() {
final int messageCount = 200;
final UUID destinationUuid = UUID.randomUUID();
final int destinationDeviceId = random.nextInt(255) + 1;
final byte destinationDeviceId = (byte) (random.nextInt(Device.MAXIMUM_DEVICE_ID) + 1);
final List<MessageProtos.Envelope> messages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
messages.add(MessageHelper.createMessage(UUID.randomUUID(), Device.PRIMARY_ID, destinationUuid, (i + 1L) * 1000,
"message " + i));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
final int messageLoadLimit = 100;
final int halfOfMessageLoadLimit = messageLoadLimit / 2;
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit);
final Publisher<?> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
messageLoadLimit);
StepVerifier.setDefaultTimeout(Duration.ofSeconds(10));
@ -170,7 +173,7 @@ class MessagesDynamoDbTest {
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// the first 100 should be fetched and buffered, but further requests should fail
.then(() -> DYNAMO_DB_EXTENSION.stopServer())
.then(DYNAMO_DB_EXTENSION::stopServer)
.thenRequest(halfOfMessageLoadLimit)
.expectNextCount(halfOfMessageLoadLimit)
// weve consumed all the buffered messages, so a single request will fail
@ -183,22 +186,23 @@ class MessagesDynamoDbTest {
void testDeleteForDestination() {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte deviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, deviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, (byte) 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid).join();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(destinationUuid, deviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@ -206,23 +210,26 @@ class MessagesDynamoDbTest {
void testDeleteForDestinationDevice() {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2).join();
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, destinationDeviceId2).join();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@ -230,15 +237,17 @@ class MessagesDynamoDbTest {
void testDeleteMessageByDestinationAndGuid() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
final Optional<MessageProtos.Envelope> deletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid(
@ -247,11 +256,12 @@ class MessagesDynamoDbTest {
assertThat(deletedMessage).isPresent();
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
final Optional<MessageProtos.Envelope> alreadyDeletedMessage = messagesDynamoDb.deleteMessageByDestinationAndGuid(
@ -266,29 +276,32 @@ class MessagesDynamoDbTest {
void testDeleteSingleMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, Device.PRIMARY_ID);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, Device.PRIMARY_ID);
final byte destinationDeviceId2 = 2;
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, destinationDeviceId2);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
messagesDynamoDb.deleteMessage(secondDestinationUuid, Device.PRIMARY_ID,
UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS);
assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
assertThat(load(destinationUuid, destinationDeviceId2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1)
.element(0).isEqualTo(MESSAGE3);
assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
assertThat(load(secondDestinationUuid, Device.PRIMARY_ID, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
private List<MessageProtos.Envelope> load(final UUID destinationUuid, final long destinationDeviceId,
private List<MessageProtos.Envelope> load(final UUID destinationUuid, final byte destinationDeviceId,
final int count) {
return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count))
.take(count, true)

Some files were not shown because too many files have changed in this diff Show More