Improve testing of MultiRecipientMessageProvider
This commit is contained in:
parent
378d7987a8
commit
a7d5d51fb4
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
import javax.validation.constraints.Max;
|
import javax.validation.constraints.Max;
|
||||||
|
@ -53,6 +54,37 @@ public class MultiRecipientMessage {
|
||||||
public byte[] getPerRecipientKeyMaterial() {
|
public byte[] getPerRecipientKeyMaterial() {
|
||||||
return perRecipientKeyMaterial;
|
return perRecipientKeyMaterial;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(final Object o) {
|
||||||
|
if (this == o)
|
||||||
|
return true;
|
||||||
|
if (o == null || getClass() != o.getClass())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
Recipient recipient = (Recipient) o;
|
||||||
|
|
||||||
|
if (deviceId != recipient.deviceId)
|
||||||
|
return false;
|
||||||
|
if (registrationId != recipient.registrationId)
|
||||||
|
return false;
|
||||||
|
if (!uuid.equals(recipient.uuid))
|
||||||
|
return false;
|
||||||
|
return Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
int result = uuid.hashCode();
|
||||||
|
result = 31 * result + (int) (deviceId ^ (deviceId >>> 32));
|
||||||
|
result = 31 * result + registrationId;
|
||||||
|
result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String toString() {
|
||||||
|
return "Recipient(" + uuid + ", " + deviceId + ", " + registrationId + ", " + Arrays.toString(perRecipientKeyMaterial) + ")";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
|
|
|
@ -107,10 +107,11 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
|
||||||
*
|
*
|
||||||
* @return the varint value
|
* @return the varint value
|
||||||
*/
|
*/
|
||||||
private long readVarint(InputStream stream) throws IOException, WebApplicationException {
|
@VisibleForTesting
|
||||||
|
public static long readVarint(InputStream stream) throws IOException, WebApplicationException {
|
||||||
boolean hasMore = true;
|
boolean hasMore = true;
|
||||||
int currentOffset = 0;
|
int currentOffset = 0;
|
||||||
int result = 0;
|
long result = 0;
|
||||||
while (hasMore) {
|
while (hasMore) {
|
||||||
if (currentOffset >= 64) {
|
if (currentOffset >= 64) {
|
||||||
throw new BadRequestException("varint is too large");
|
throw new BadRequestException("varint is too large");
|
||||||
|
@ -123,7 +124,7 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
|
||||||
throw new BadRequestException("varint is too large");
|
throw new BadRequestException("varint is too large");
|
||||||
}
|
}
|
||||||
hasMore = (b & 0x80) != 0;
|
hasMore = (b & 0x80) != 0;
|
||||||
result |= (b & 0x7F) << currentOffset;
|
result |= (b & 0x7FL) << currentOffset;
|
||||||
currentOffset += 7;
|
currentOffset += 7;
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -40,8 +40,10 @@ import java.nio.ByteOrder;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import java.util.Random;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.Callable;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
|
@ -68,6 +70,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
|
||||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
|
||||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
|
||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
|
||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
|
||||||
|
@ -719,20 +722,28 @@ class MessageControllerTest {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception {
|
private static void writePayloadDeviceId(ByteBuffer bb, long deviceId) {
|
||||||
bb.putLong(msb); // uuid (first 8 bytes)
|
|
||||||
bb.putLong(lsb); // uuid (last 8 bytes)
|
|
||||||
long x = deviceId;
|
long x = deviceId;
|
||||||
// write the device-id in the 7-bit varint format we use, least significant bytes first.
|
// write the device-id in the 7-bit varint format we use, least significant bytes first.
|
||||||
do {
|
do {
|
||||||
bb.put((byte)(x & 0x7f));
|
long b = x & 0x7f;
|
||||||
x = x >>> 7;
|
x = x >>> 7;
|
||||||
|
if (x != 0) b |= 0x80;
|
||||||
|
bb.put((byte)b);
|
||||||
} while (x != 0);
|
} while (x != 0);
|
||||||
bb.putShort((short) regId); // registration id short
|
|
||||||
bb.put(new byte[48]); // key material (48 bytes)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
|
private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception {
|
||||||
|
long msb = r.getUuid().getMostSignificantBits();
|
||||||
|
long lsb = r.getUuid().getLeastSignificantBits();
|
||||||
|
bb.putLong(msb); // uuid (first 8 bytes)
|
||||||
|
bb.putLong(lsb); // uuid (last 8 bytes)
|
||||||
|
writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes)
|
||||||
|
bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes)
|
||||||
|
bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
|
||||||
// initialize a binary payload according to our wire format
|
// initialize a binary payload according to our wire format
|
||||||
ByteBuffer bb = ByteBuffer.wrap(buffer);
|
ByteBuffer bb = ByteBuffer.wrap(buffer);
|
||||||
bb.order(ByteOrder.BIG_ENDIAN);
|
bb.order(ByteOrder.BIG_ENDIAN);
|
||||||
|
@ -743,11 +754,7 @@ class MessageControllerTest {
|
||||||
|
|
||||||
Iterator<Recipient> it = recipients.iterator();
|
Iterator<Recipient> it = recipients.iterator();
|
||||||
while (it.hasNext()) {
|
while (it.hasNext()) {
|
||||||
Recipient r = it.next();
|
writeMultiPayloadRecipient(bb, it.next());
|
||||||
UUID uuid = r.getUuid();
|
|
||||||
long msb = uuid.getMostSignificantBits();
|
|
||||||
long lsb = uuid.getLeastSignificantBits();
|
|
||||||
writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// now write the actual message body (empty for now)
|
// now write the actual message body (empty for now)
|
||||||
|
@ -1003,4 +1010,62 @@ class MessageControllerTest {
|
||||||
|
|
||||||
return builder.build();
|
return builder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static Recipient genRecipient(Random rng) {
|
||||||
|
UUID u1 = UUID.randomUUID(); // non-null
|
||||||
|
long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903
|
||||||
|
int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
|
||||||
|
byte[] perKeyBytes = new byte[48]; // size=48, non-null
|
||||||
|
rng.nextBytes(perKeyBytes);
|
||||||
|
return new Recipient(u1, d1, dr1, perKeyBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
|
||||||
|
ByteBuffer bb = ByteBuffer.wrap(bytes);
|
||||||
|
writePayloadDeviceId(bb, expected);
|
||||||
|
InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
|
||||||
|
long got = MultiRecipientMessageProvider.readVarint(stream);
|
||||||
|
assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testVarintPayload() throws Exception {
|
||||||
|
Random rng = new Random();
|
||||||
|
byte[] bytes = new byte[12];
|
||||||
|
|
||||||
|
// some static test cases
|
||||||
|
for (long i = 1L; i <= 10L; i++) {
|
||||||
|
roundTripVarint(i, bytes);
|
||||||
|
}
|
||||||
|
roundTripVarint(Long.MAX_VALUE, bytes);
|
||||||
|
|
||||||
|
for (int i = 0; i < 1000; i++) {
|
||||||
|
// we need to ensure positive device IDs
|
||||||
|
long start = rng.nextLong() & Long.MAX_VALUE;
|
||||||
|
if (start == 0L) start = 1L;
|
||||||
|
|
||||||
|
// run the test for this case
|
||||||
|
roundTripVarint(start, bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testMultiPayloadRoundtrip() throws Exception {
|
||||||
|
Random rng = new java.util.Random();
|
||||||
|
List<Recipient> expected = new LinkedList<>();
|
||||||
|
for(int i = 0; i < 100; i++) {
|
||||||
|
expected.add(genRecipient(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
byte[] buffer = new byte[100 + expected.size() * 100];
|
||||||
|
InputStream entityStream = initializeMultiPayload(expected, buffer);
|
||||||
|
MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
|
||||||
|
// the provider ignores the headers, java reflection, etc. so we don't use those here.
|
||||||
|
MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
|
||||||
|
List<Recipient> got = Arrays.asList(res.getRecipients());
|
||||||
|
|
||||||
|
assertEquals(expected, got);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue