From aa99e202b47d11788fd79212bb6a636992792baa Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 8 Feb 2021 11:45:57 -0500 Subject: [PATCH] Clarify behavioral contract of the pre-key store --- .../controllers/KeysController.java | 27 ++++---- .../textsecuregcm/storage/KeyRecord.java | 62 ------------------- .../textsecuregcm/storage/KeysDynamoDb.java | 17 ++--- .../storage/mappers/KeyRecordRowMapper.java | 25 -------- .../storage/KeysDynamoDbTest.java | 16 ++--- .../tests/controllers/KeysControllerTest.java | 24 +++---- 6 files changed, 39 insertions(+), 132 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index caf934de2..76f145f97 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -21,7 +21,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.KeyRecord; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import javax.validation.Valid; @@ -35,8 +34,10 @@ import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import java.util.Collections; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -120,28 +121,22 @@ public class KeysController { rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getNumber() + "." + deviceId); } - List targetKeys = getLocalKeys(target.get(), deviceId); - List devices = new LinkedList<>(); + Map preKeysByDeviceId = getLocalKeys(target.get(), deviceId); + List responseItems = new LinkedList<>(); for (Device device : target.get().getDevices()) { if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) { SignedPreKey signedPreKey = device.getSignedPreKey(); - PreKey preKey = null; - - for (KeyRecord keyRecord : targetKeys) { - if (keyRecord.getDeviceId() == device.getId()) { - preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey()); - } - } + PreKey preKey = preKeysByDeviceId.get(device.getId()); if (signedPreKey != null || preKey != null) { - devices.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey)); + responseItems.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey)); } } } - if (devices.isEmpty()) return Optional.empty(); - else return Optional.of(new PreKeyResponse(target.get().getIdentityKey(), devices)); + if (responseItems.isEmpty()) return Optional.empty(); + else return Optional.of(new PreKeyResponse(target.get().getIdentityKey(), responseItems)); } @Timed @@ -172,7 +167,7 @@ public class KeysController { else return Optional.empty(); } - private List getLocalKeys(Account destination, String deviceIdSelector) { + private Map getLocalKeys(Account destination, String deviceIdSelector) { try { if (deviceIdSelector.equals("*")) { return keysDynamoDb.take(destination); @@ -180,7 +175,9 @@ public class KeysController { long deviceId = Long.parseLong(deviceIdSelector); - return keysDynamoDb.take(destination, deviceId); + return keysDynamoDb.take(destination, deviceId) + .map(preKey -> Map.of(deviceId, preKey)) + .orElse(Collections.emptyMap()); } catch (NumberFormatException e) { throw new WebApplicationException(Response.status(422).build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java deleted file mode 100644 index c4f79487c..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeyRecord.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import java.util.Objects; - -public class KeyRecord { - - private long id; - private String number; - private long deviceId; - private long keyId; - private String publicKey; - - public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) { - this.id = id; - this.number = number; - this.deviceId = deviceId; - this.keyId = keyId; - this.publicKey = publicKey; - } - - public long getId() { - return id; - } - - public String getNumber() { - return number; - } - - public long getDeviceId() { - return deviceId; - } - - public long getKeyId() { - return keyId; - } - - public String getPublicKey() { - return publicKey; - } - - @Override - public boolean equals(final Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - final KeyRecord keyRecord = (KeyRecord)o; - return id == keyRecord.id && - deviceId == keyRecord.deviceId && - keyId == keyRecord.keyId && - Objects.equals(number, keyRecord.number) && - Objects.equals(publicKey, keyRecord.publicKey); - } - - @Override - public int hashCode() { - return Objects.hash(id, number, deviceId, keyId, publicKey); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java index 009afaf60..482a442b8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java @@ -25,8 +25,10 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; import static com.codahale.metrics.MetricRegistry.name; @@ -69,7 +71,7 @@ public class KeysDynamoDb extends AbstractDynamoDbStore { }); } - public List take(final Account account, final long deviceId) { + public Optional take(final Account account, final long deviceId) { return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { final byte[] partitionKey = getPartitionKey(account.getUuid()); @@ -90,29 +92,28 @@ public class KeysDynamoDb extends AbstractDynamoDbStore { final DeleteItemOutcome outcome = table.deleteItem(deleteItemSpec); if (outcome.getItem() != null) { - final PreKey preKey = getPreKeyFromItem(outcome.getItem()); - return List.of(new KeyRecord(-1, account.getNumber(), deviceId, preKey.getKeyId(), preKey.getPublicKey())); + return Optional.of(getPreKeyFromItem(outcome.getItem())); } contestedKeys++; } - return Collections.emptyList(); + return Optional.empty(); } finally { CONTESTED_KEY_DISTRIBUTION.record(contestedKeys); } }); } - public List take(final Account account) { + public Map take(final Account account) { return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { - final List keyRecords = new ArrayList<>(); + final Map preKeysByDeviceId = new HashMap<>(); for (final Device device : account.getDevices()) { - keyRecords.addAll(take(account, device.getId())); + take(account, device.getId()).ifPresent(preKey -> preKeysByDeviceId.put(device.getId(), preKey)); } - return keyRecords; + return preKeysByDeviceId; }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java deleted file mode 100644 index 44573d920..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/KeyRecordRowMapper.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage.mappers; - -import org.jdbi.v3.core.mapper.RowMapper; -import org.jdbi.v3.core.statement.StatementContext; -import org.whispersystems.textsecuregcm.storage.KeyRecord; - -import java.sql.ResultSet; -import java.sql.SQLException; - -public class KeyRecordRowMapper implements RowMapper { - - @Override - public KeyRecord map(ResultSet resultSet, StatementContext ctx) throws SQLException { - return new KeyRecord(resultSet.getLong("id"), - resultSet.getString("number"), - resultSet.getLong("device_id"), - resultSet.getLong("key_id"), - resultSet.getString("public_key")); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java index 1fcc58880..b3170b2b8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDbTest.java @@ -11,8 +11,9 @@ import org.junit.Test; import org.whispersystems.textsecuregcm.entities.PreKey; import java.util.Collections; -import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -70,7 +71,7 @@ public class KeysDynamoDbTest { when(secondDevice.getId()).thenReturn(DEVICE_ID + 1); when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice)); - assertEquals(Collections.emptyList(), keysDynamoDb.take(account)); + assertEquals(Collections.emptyMap(), keysDynamoDb.take(account)); final PreKey firstDevicePreKey = new PreKey(1, "public-key"); final PreKey secondDevicePreKey = new PreKey(2, "second-key"); @@ -78,23 +79,22 @@ public class KeysDynamoDbTest { keysDynamoDb.store(account, DEVICE_ID, List.of(firstDevicePreKey)); keysDynamoDb.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey)); - final Set expectedKeys = Set.of( - new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, firstDevicePreKey.getKeyId(), firstDevicePreKey.getPublicKey()), - new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID + 1, secondDevicePreKey.getKeyId(), secondDevicePreKey.getPublicKey())); + final Map expectedKeys = Map.of(DEVICE_ID, firstDevicePreKey, + DEVICE_ID + 1, secondDevicePreKey); - assertEquals(expectedKeys, new HashSet<>(keysDynamoDb.take(account))); + assertEquals(expectedKeys, keysDynamoDb.take(account)); assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID)); assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1)); } @Test public void testTakeAccountAndDeviceId() { - assertEquals(Collections.emptyList(), keysDynamoDb.take(account, DEVICE_ID)); + assertEquals(Optional.empty(), keysDynamoDb.take(account, DEVICE_ID)); final PreKey preKey = new PreKey(1, "public-key"); keysDynamoDb.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); - assertEquals(List.of(new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, preKey.getKeyId(), preKey.getPublicKey())), keysDynamoDb.take(account, DEVICE_ID)); + assertEquals(Optional.of(preKey), keysDynamoDb.take(account, DEVICE_ID)); assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 954504331..f6813e56d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.KeyRecord; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @@ -38,6 +37,7 @@ import javax.ws.rs.core.Response; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -59,10 +59,10 @@ public class KeysControllerTest { private static int SAMPLE_REGISTRATION_ID2 = 1002; private static int SAMPLE_REGISTRATION_ID4 = 1555; - private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1"); - private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3"); - private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5"); - private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6"); + private final PreKey SAMPLE_KEY = new PreKey(1234, "test1"); + private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3"); + private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5"); + private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6"); private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" ); @@ -140,16 +140,12 @@ public class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - List singleDevice = new LinkedList<>(); - singleDevice.add(SAMPLE_KEY); - when(keysDynamoDb.take(eq(existsAccount), eq(1L))).thenReturn(singleDevice); + when(keysDynamoDb.take(eq(existsAccount), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY)); - List multiDevice = new LinkedList<>(); - multiDevice.add(SAMPLE_KEY); - multiDevice.add(SAMPLE_KEY2); - multiDevice.add(SAMPLE_KEY3); - multiDevice.add(SAMPLE_KEY4); - when(keysDynamoDb.take(existsAccount)).thenReturn(multiDevice); + when(keysDynamoDb.take(existsAccount)).thenReturn(Map.of(1L, SAMPLE_KEY, + 2L, SAMPLE_KEY2, + 3L, SAMPLE_KEY3, + 4L, SAMPLE_KEY4)); when(keysDynamoDb.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5);