Improve testing of MultiRecipientMessageProvider
This commit is contained in:
		
							parent
							
								
									378d7987a8
								
							
						
					
					
						commit
						a7d5d51fb4
					
				| 
						 | 
					@ -5,6 +5,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package org.whispersystems.textsecuregcm.entities;
 | 
					package org.whispersystems.textsecuregcm.entities;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import java.util.Arrays;
 | 
				
			||||||
import java.util.UUID;
 | 
					import java.util.UUID;
 | 
				
			||||||
import javax.validation.Valid;
 | 
					import javax.validation.Valid;
 | 
				
			||||||
import javax.validation.constraints.Max;
 | 
					import javax.validation.constraints.Max;
 | 
				
			||||||
| 
						 | 
					@ -53,6 +54,37 @@ public class MultiRecipientMessage {
 | 
				
			||||||
    public byte[] getPerRecipientKeyMaterial() {
 | 
					    public byte[] getPerRecipientKeyMaterial() {
 | 
				
			||||||
      return perRecipientKeyMaterial;
 | 
					      return perRecipientKeyMaterial;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public boolean equals(final Object o) {
 | 
				
			||||||
 | 
					      if (this == o)
 | 
				
			||||||
 | 
					        return true;
 | 
				
			||||||
 | 
					      if (o == null || getClass() != o.getClass())
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      Recipient recipient = (Recipient) o;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (deviceId != recipient.deviceId)
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					      if (registrationId != recipient.registrationId)
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					      if (!uuid.equals(recipient.uuid))
 | 
				
			||||||
 | 
					        return false;
 | 
				
			||||||
 | 
					      return Arrays.equals(perRecipientKeyMaterial, recipient.perRecipientKeyMaterial);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @Override
 | 
				
			||||||
 | 
					    public int hashCode() {
 | 
				
			||||||
 | 
					      int result = uuid.hashCode();
 | 
				
			||||||
 | 
					      result = 31 * result + (int) (deviceId ^ (deviceId >>> 32));
 | 
				
			||||||
 | 
					      result = 31 * result + registrationId;
 | 
				
			||||||
 | 
					      result = 31 * result + Arrays.hashCode(perRecipientKeyMaterial);
 | 
				
			||||||
 | 
					      return result;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    public String toString() {
 | 
				
			||||||
 | 
					      return "Recipient(" + uuid + ", " + deviceId + ", " + registrationId + ", " + Arrays.toString(perRecipientKeyMaterial) + ")";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @NotNull
 | 
					  @NotNull
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -107,10 +107,11 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @return the varint value
 | 
					   * @return the varint value
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  private long readVarint(InputStream stream) throws IOException, WebApplicationException {
 | 
					  @VisibleForTesting
 | 
				
			||||||
 | 
					  public static long readVarint(InputStream stream) throws IOException, WebApplicationException {
 | 
				
			||||||
    boolean hasMore = true;
 | 
					    boolean hasMore = true;
 | 
				
			||||||
    int currentOffset = 0;
 | 
					    int currentOffset = 0;
 | 
				
			||||||
    int result = 0;
 | 
					    long result = 0;
 | 
				
			||||||
    while (hasMore) {
 | 
					    while (hasMore) {
 | 
				
			||||||
      if (currentOffset >= 64) {
 | 
					      if (currentOffset >= 64) {
 | 
				
			||||||
        throw new BadRequestException("varint is too large");
 | 
					        throw new BadRequestException("varint is too large");
 | 
				
			||||||
| 
						 | 
					@ -123,7 +124,7 @@ public class MultiRecipientMessageProvider implements MessageBodyReader<MultiRec
 | 
				
			||||||
        throw new BadRequestException("varint is too large");
 | 
					        throw new BadRequestException("varint is too large");
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      hasMore = (b & 0x80) != 0;
 | 
					      hasMore = (b & 0x80) != 0;
 | 
				
			||||||
      result |= (b & 0x7F) << currentOffset;
 | 
					      result |= (b & 0x7FL) << currentOffset;
 | 
				
			||||||
      currentOffset += 7;
 | 
					      currentOffset += 7;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return result;
 | 
					    return result;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,8 +40,10 @@ import java.nio.ByteOrder;
 | 
				
			||||||
import java.util.Arrays;
 | 
					import java.util.Arrays;
 | 
				
			||||||
import java.util.Base64;
 | 
					import java.util.Base64;
 | 
				
			||||||
import java.util.Iterator;
 | 
					import java.util.Iterator;
 | 
				
			||||||
 | 
					import java.util.LinkedList;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
import java.util.Optional;
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					import java.util.Random;
 | 
				
			||||||
import java.util.UUID;
 | 
					import java.util.UUID;
 | 
				
			||||||
import java.util.concurrent.Callable;
 | 
					import java.util.concurrent.Callable;
 | 
				
			||||||
import java.util.concurrent.ExecutorService;
 | 
					import java.util.concurrent.ExecutorService;
 | 
				
			||||||
| 
						 | 
					@ -68,6 +70,7 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
 | 
					import org.whispersystems.textsecuregcm.entities.MessageProtos;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
 | 
					import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
 | 
					import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
 | 
				
			||||||
 | 
					import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
 | 
					import org.whispersystems.textsecuregcm.entities.MultiRecipientMessage.Recipient;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
 | 
					import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
 | 
					import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
 | 
				
			||||||
| 
						 | 
					@ -719,20 +722,28 @@ class MessageControllerTest {
 | 
				
			||||||
    );
 | 
					    );
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private void writeMultiPayloadRecipient(ByteBuffer bb, long msb, long lsb, long deviceId, int regId) throws Exception {
 | 
					  private static void writePayloadDeviceId(ByteBuffer bb, long deviceId) {
 | 
				
			||||||
    bb.putLong(msb);            // uuid (first 8 bytes)
 | 
					 | 
				
			||||||
    bb.putLong(lsb);            // uuid (last 8 bytes)
 | 
					 | 
				
			||||||
    long x = deviceId;
 | 
					    long x = deviceId;
 | 
				
			||||||
    // write the device-id in the 7-bit varint format we use, least significant bytes first.
 | 
					    // write the device-id in the 7-bit varint format we use, least significant bytes first.
 | 
				
			||||||
    do {
 | 
					    do {
 | 
				
			||||||
      bb.put((byte)(x & 0x7f));
 | 
					      long b = x & 0x7f;
 | 
				
			||||||
      x = x >>> 7;
 | 
					      x = x >>> 7;
 | 
				
			||||||
 | 
					      if (x != 0) b |= 0x80;
 | 
				
			||||||
 | 
					      bb.put((byte)b);
 | 
				
			||||||
    } while (x != 0);
 | 
					    } while (x != 0);
 | 
				
			||||||
    bb.putShort((short) regId); // registration id short
 | 
					 | 
				
			||||||
    bb.put(new byte[48]);       // key material (48 bytes)
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
 | 
					  private static void writeMultiPayloadRecipient(ByteBuffer bb, Recipient r) throws Exception {
 | 
				
			||||||
 | 
					    long msb = r.getUuid().getMostSignificantBits();
 | 
				
			||||||
 | 
					    long lsb = r.getUuid().getLeastSignificantBits();
 | 
				
			||||||
 | 
					    bb.putLong(msb);            // uuid (first 8 bytes)
 | 
				
			||||||
 | 
					    bb.putLong(lsb);            // uuid (last 8 bytes)
 | 
				
			||||||
 | 
					    writePayloadDeviceId(bb, r.getDeviceId()); // device id (1-9 bytes)
 | 
				
			||||||
 | 
					    bb.putShort((short) r.getRegistrationId()); // registration id (2 bytes)
 | 
				
			||||||
 | 
					    bb.put(r.getPerRecipientKeyMaterial()); // key material (48 bytes)
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static InputStream initializeMultiPayload(List<Recipient> recipients, byte[] buffer) throws Exception {
 | 
				
			||||||
    // initialize a binary payload according to our wire format
 | 
					    // initialize a binary payload according to our wire format
 | 
				
			||||||
    ByteBuffer bb = ByteBuffer.wrap(buffer);
 | 
					    ByteBuffer bb = ByteBuffer.wrap(buffer);
 | 
				
			||||||
    bb.order(ByteOrder.BIG_ENDIAN);
 | 
					    bb.order(ByteOrder.BIG_ENDIAN);
 | 
				
			||||||
| 
						 | 
					@ -743,11 +754,7 @@ class MessageControllerTest {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Iterator<Recipient> it = recipients.iterator();
 | 
					    Iterator<Recipient> it = recipients.iterator();
 | 
				
			||||||
    while (it.hasNext()) {
 | 
					    while (it.hasNext()) {
 | 
				
			||||||
      Recipient r = it.next();
 | 
					      writeMultiPayloadRecipient(bb, it.next());
 | 
				
			||||||
      UUID uuid = r.getUuid();
 | 
					 | 
				
			||||||
      long msb = uuid.getMostSignificantBits();
 | 
					 | 
				
			||||||
      long lsb = uuid.getLeastSignificantBits();
 | 
					 | 
				
			||||||
      writeMultiPayloadRecipient(bb, msb, lsb, r.getDeviceId(), r.getRegistrationId());
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // now write the actual message body (empty for now)
 | 
					    // now write the actual message body (empty for now)
 | 
				
			||||||
| 
						 | 
					@ -1003,4 +1010,62 @@ class MessageControllerTest {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return builder.build();
 | 
					    return builder.build();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static Recipient genRecipient(Random rng) {
 | 
				
			||||||
 | 
					    UUID u1 = UUID.randomUUID(); // non-null
 | 
				
			||||||
 | 
					    long d1 = rng.nextLong() & 0x3fffffffffffffffL + 1; // 1 to 4611686018427387903
 | 
				
			||||||
 | 
					    int dr1 = rng.nextInt() & 0xffff; // 0 to 65535
 | 
				
			||||||
 | 
					    byte[] perKeyBytes = new byte[48]; // size=48, non-null
 | 
				
			||||||
 | 
					    rng.nextBytes(perKeyBytes);
 | 
				
			||||||
 | 
					    return new Recipient(u1, d1, dr1, perKeyBytes);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  private static void roundTripVarint(long expected, byte [] bytes) throws Exception {
 | 
				
			||||||
 | 
					    ByteBuffer bb = ByteBuffer.wrap(bytes);
 | 
				
			||||||
 | 
					    writePayloadDeviceId(bb, expected);
 | 
				
			||||||
 | 
					    InputStream stream = new ByteArrayInputStream(bytes, 0, bb.position());
 | 
				
			||||||
 | 
					    long got = MultiRecipientMessageProvider.readVarint(stream);
 | 
				
			||||||
 | 
					    assertEquals(expected, got, String.format("encoded as: %s", Arrays.toString(bytes)));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testVarintPayload() throws Exception {
 | 
				
			||||||
 | 
					    Random rng = new Random();
 | 
				
			||||||
 | 
					    byte[] bytes = new byte[12];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // some static test cases
 | 
				
			||||||
 | 
					    for (long i = 1L; i <= 10L; i++) {
 | 
				
			||||||
 | 
					      roundTripVarint(i, bytes);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    roundTripVarint(Long.MAX_VALUE, bytes);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (int i = 0; i < 1000; i++) {
 | 
				
			||||||
 | 
					      // we need to ensure positive device IDs
 | 
				
			||||||
 | 
					      long start = rng.nextLong() & Long.MAX_VALUE;
 | 
				
			||||||
 | 
					      if (start == 0L) start = 1L;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // run the test for this case
 | 
				
			||||||
 | 
					      roundTripVarint(start, bytes);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testMultiPayloadRoundtrip() throws Exception {
 | 
				
			||||||
 | 
					    Random rng = new java.util.Random();
 | 
				
			||||||
 | 
					    List<Recipient> expected = new LinkedList<>();
 | 
				
			||||||
 | 
					    for(int i = 0; i < 100; i++) {
 | 
				
			||||||
 | 
					      expected.add(genRecipient(rng));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    byte[] buffer = new byte[100 + expected.size() * 100];
 | 
				
			||||||
 | 
					    InputStream entityStream = initializeMultiPayload(expected, buffer);
 | 
				
			||||||
 | 
					    MultiRecipientMessageProvider provider = new MultiRecipientMessageProvider();
 | 
				
			||||||
 | 
					    // the provider ignores the headers, java reflection, etc. so we don't use those here.
 | 
				
			||||||
 | 
					    MultiRecipientMessage res = provider.readFrom(null, null, null, null, null, entityStream);
 | 
				
			||||||
 | 
					    List<Recipient> got = Arrays.asList(res.getRecipients());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assertEquals(expected, got);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue