diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java index eccd1e97d..aecba4ac4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/MultiRecipientMessage.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.entities; +import java.util.Arrays; import java.util.UUID; import javax.validation.Valid; import javax.validation.constraints.Max; @@ -53,6 +54,37 @@ public class MultiRecipientMessage { public byte[] getPerRecipientKeyMaterial() { return perRecipientKeyMaterial; } + + @Override + public boolean equals(final Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + Recipient recipient = (Recipient) o; + + if (deviceId != recipient.deviceId) + return false; + if (registrationId != recipient.registrationId) + return false; + if (!uuid.equals(recipient.uuid)) + return false; + return Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial); + } + + @Override + public int hashCode() { + int result = uuid.hashCode(); + result = 31 * result + (int) (deviceId ^ (deviceId >>> 32)); + result = 31 * result + registrationId; + result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial); + return result; + } + + public String toString() { + return "Recipient(" + uuid + ", " + deviceId + ", " + registrationId + ", " + Arrays.toString(perRecipientKeyMaterial) + ")"; + } } @NotNull diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java index a19cf087b..579f98eac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/providers/MultiRecipientMessageProvider.java @@ -107,10 +107,11 @@ public class MultiRecipientMessageProvider implements MessageBodyReader= 64) { throw new BadRequestException("varint is too large"); @@ -123,7 +124,7 @@ public class MultiRecipientMessageProvider implements MessageBodyReader>> 7; + if (x != 0) b |= 0x80; + bb.put((byte)b); } while (x != 0); - bb.putShort((short) regId); // registration id short - bb.put(new byte[48]); // key material (48 bytes) } - private InputStream initializeMultiPayload(List recipients, byte[] buffer) throws Exception { + private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception { + long msb = r.getUuid().getMostSignificantBits(); + long lsb = r.getUuid().getLeastSignificantBits(); + bb.putLong(msb); // uuid (first 8 bytes) + bb.putLong(lsb); // uuid (last 8 bytes) + writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes) + bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes) + bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes) + } + + private static InputStream initializeMultiPayload(List recipients, byte[] buffer) throws Exception { // initialize a binary payload according to our wire format ByteBuffer bb = ByteBuffer.wrap(buffer); bb.order(ByteOrder.BIG_ENDIAN); @@ -743,11 +754,7 @@ class MessageControllerTest { Iterator it = recipients.iterator(); while (it.hasNext()) { - Recipient r = it.next(); - UUID uuid = r.getUuid(); - long msb = uuid.getMostSignificantBits(); - long lsb = uuid.getLeastSignificantBits(); - writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId()); + writeMultiPayloadRecipient(bb, it.next()); } // now write the actual message body (empty for now) @@ -1003,4 +1010,62 @@ class MessageControllerTest { return builder.build(); } + + private static Recipient genRecipient(Random rng) { + UUID u1 = UUID.randomUUID(); // non-null + long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903 + int dr1 = rng.nextInt() & 0xffff; // 0 to 65535 + byte[] perKeyBytes = new byte[48]; // size=48, non-null + rng.nextBytes(perKeyBytes); + return new Recipient(u1, d1, dr1, perKeyBytes); + } + + private static void roundTripVarint(long expected, byte [] bytes) throws Exception { + ByteBuffer bb = ByteBuffer.wrap(bytes); + writePayloadDeviceId(bb, expected); + InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position()); + long got = MultiRecipientMessageProvider.readVarint(stream); + assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes))); + } + + @Test + void testVarintPayload() throws Exception { + Random rng = new Random(); + byte[] bytes = new byte[12]; + + // some static test cases + for (long i = 1L; i <= 10L; i++) { + roundTripVarint(i, bytes); + } + roundTripVarint(Long.MAX_VALUE, bytes); + + for (int i = 0; i < 1000; i++) { + // we need to ensure positive device IDs + long start = rng.nextLong() & Long.MAX_VALUE; + if (start == 0L) start = 1L; + + // run the test for this case + roundTripVarint(start, bytes); + } + } + + @Test + void testMultiPayloadRoundtrip() throws Exception { + Random rng = new java.util.Random(); + List expected = new LinkedList<>(); + for(int i = 0; i < 100; i++) { + expected.add(genRecipient(rng)); + } + + byte[] buffer = new byte[100 + expected.size() * 100]; + InputStream entityStream = initializeMultiPayload(expected, buffer); + MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider(); + // the provider ignores the headers, java reflection, etc. so we don't use those here. + MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream); + List got = Arrays.asList(res.getRecipients()); + + assertEquals(expected, got); + } + + }