diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index df8c2923e..863beb9ce 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -11,6 +11,7 @@ import io.dropwizard.auth.Auth; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -77,7 +78,7 @@ public class KeysController { @GET @Produces(MediaType.APPLICATION_JSON) public PreKeyCount getStatus(@Auth AuthenticatedAccount auth) { - int count = keys.getCount(auth.getAccount(), auth.getAuthenticatedDevice().getId()); + int count = keys.getCount(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId()); if (count > 0) { count = count - 1; @@ -109,7 +110,7 @@ public class KeysController { }); } - keys.store(account, device.getId(), preKeys.getPreKeys()); + keys.store(account.getUuid(), device.getId(), preKeys.getPreKeys()); } @Timed @@ -211,18 +212,26 @@ public class KeysController { } private Map getLocalKeys(Account destination, String deviceIdSelector) { - try { - if (deviceIdSelector.equals("*")) { - return keys.take(destination); + final Map preKeys; + + if (deviceIdSelector.equals("*")) { + preKeys = new HashMap<>(); + + for (final Device device : destination.getDevices()) { + keys.take(destination.getUuid(), device.getId()).ifPresent(preKey -> preKeys.put(device.getId(), preKey)); } + } else { + try { + long deviceId = Long.parseLong(deviceIdSelector); - long deviceId = Long.parseLong(deviceIdSelector); - - return keys.take(destination, deviceId) - .map(preKey -> Map.of(deviceId, preKey)) - .orElse(Collections.emptyMap()); - } catch (NumberFormatException e) { - throw new WebApplicationException(Response.status(422).build()); + preKeys = keys.take(destination.getUuid(), deviceId) + .map(preKey -> Map.of(deviceId, preKey)) + .orElse(Collections.emptyMap()); + } catch (NumberFormatException e) { + throw new WebApplicationException(Response.status(422).build()); + } } + + return preKeys; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java index f5d94d0f1..9b7255b62 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java @@ -13,7 +13,6 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -42,7 +41,6 @@ public class Keys extends AbstractDynamoDbStore { private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys")); private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice")); - private static final Timer TAKE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(Keys.class, "takeKeyForAccount")); private static final Timer GET_KEY_COUNT_TIMER = Metrics.timer(name(Keys.class, "getKeyCount")); private static final Timer DELETE_KEYS_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForDevice")); private static final Timer DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForAccount")); @@ -54,16 +52,16 @@ public class Keys extends AbstractDynamoDbStore { this.tableName = tableName; } - public void store(final Account account, final long deviceId, final List keys) { + public void store(final UUID identifier, final long deviceId, final List keys) { STORE_KEYS_TIMER.record(() -> { - delete(account.getUuid(), deviceId); + delete(identifier, deviceId); writeInBatches(keys, batch -> { List items = new ArrayList<>(); for (final PreKey preKey : batch) { items.add(WriteRequest.builder() .putRequest(PutRequest.builder() - .item(getItemFromPreKey(account.getUuid(), deviceId, preKey)) + .item(getItemFromPreKey(identifier, deviceId, preKey)) .build()) .build()); } @@ -72,9 +70,9 @@ public class Keys extends AbstractDynamoDbStore { }); } - public Optional take(final Account account, final long deviceId) { + public Optional take(final UUID identifier, final long deviceId) { return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { - final AttributeValue partitionKey = getPartitionKey(account.getUuid()); + final AttributeValue partitionKey = getPartitionKey(identifier); QueryRequest queryRequest = QueryRequest.builder() .tableName(tableName) .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") @@ -113,26 +111,14 @@ public class Keys extends AbstractDynamoDbStore { }); } - public Map take(final Account account) { - return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { - final Map preKeysByDeviceId = new HashMap<>(); - - for (final Device device : account.getDevices()) { - take(account, device.getId()).ifPresent(preKey -> preKeysByDeviceId.put(device.getId(), preKey)); - } - - return preKeysByDeviceId; - }); - } - - public int getCount(final Account account, final long deviceId) { + public int getCount(final UUID identifier, final long deviceId) { return GET_KEY_COUNT_TIMER.record(() -> { QueryRequest queryRequest = QueryRequest.builder() .tableName(tableName) .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(account.getUuid()), + ":uuid", getPartitionKey(identifier), ":sortprefix", getSortKeyPrefix(deviceId))) .select(Select.COUNT) .consistentRead(false) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java index 8b2ac3319..3dc83859c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java @@ -5,135 +5,99 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; import org.whispersystems.textsecuregcm.entities.PreKey; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class KeysTest { - private Account account; private Keys keys; @ClassRule public static KeysDynamoDbRule dynamoDbRule = new KeysDynamoDbRule(); - private static final String ACCOUNT_NUMBER = "+18005551234"; + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final long DEVICE_ID = 1L; @Before public void setup() { keys = new Keys(dynamoDbRule.getDynamoDbClient(), KeysDynamoDbRule.TABLE_NAME); - - account = mock(Account.class); - when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); - when(account.getUuid()).thenReturn(UUID.randomUUID()); } @Test public void testStore() { assertEquals("Initial pre-key count for an account should be zero", - 0, keys.getCount(account, DEVICE_ID)); + 0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); - assertEquals(1, keys.getCount(account, DEVICE_ID)); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); assertEquals("Repeatedly storing same key should have no effect", - 1, keys.getCount(account, DEVICE_ID)); + 1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(account, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(2, "different-public-key"))); assertEquals("Inserting a new key should overwrite all prior keys for the given account/device", - 1, keys.getCount(account, DEVICE_ID)); + 1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(account, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(3, "third-public-key"), new PreKey(4, "fourth-public-key"))); assertEquals("Inserting multiple new keys should overwrite all prior keys for the given account/device", - 2, keys.getCount(account, DEVICE_ID)); - } - - @Test - public void testTakeAccount() { - final Device firstDevice = mock(Device.class); - final Device secondDevice = mock(Device.class); - - when(firstDevice.getId()).thenReturn(DEVICE_ID); - when(secondDevice.getId()).thenReturn(DEVICE_ID + 1); - when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice)); - - assertEquals(Collections.emptyMap(), keys.take(account)); - - final PreKey firstDevicePreKey = new PreKey(1, "public-key"); - final PreKey secondDevicePreKey = new PreKey(2, "second-key"); - - keys.store(account, DEVICE_ID, List.of(firstDevicePreKey)); - keys.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey)); - - final Map expectedKeys = Map.of(DEVICE_ID, firstDevicePreKey, - DEVICE_ID + 1, secondDevicePreKey); - - assertEquals(expectedKeys, keys.take(account)); - assertEquals(0, keys.getCount(account, DEVICE_ID)); - assertEquals(0, keys.getCount(account, DEVICE_ID + 1)); + 2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); } @Test public void testTakeAccountAndDeviceId() { - assertEquals(Optional.empty(), keys.take(account, DEVICE_ID)); + assertEquals(Optional.empty(), keys.take(ACCOUNT_UUID, DEVICE_ID)); final PreKey preKey = new PreKey(1, "public-key"); - keys.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); - assertEquals(Optional.of(preKey), keys.take(account, DEVICE_ID)); - assertEquals(1, keys.getCount(account, DEVICE_ID)); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key"))); + assertEquals(Optional.of(preKey), keys.take(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); } @Test public void testGetCount() { - assertEquals(0, keys.getCount(account, DEVICE_ID)); + assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); - keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"))); - assertEquals(1, keys.getCount(account, DEVICE_ID)); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"))); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); } @Test public void testDeleteByAccount() { - keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); - keys.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); - assertEquals(2, keys.getCount(account, DEVICE_ID)); - assertEquals(1, keys.getCount(account, DEVICE_ID + 1)); + assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); - keys.delete(account.getUuid()); + keys.delete(ACCOUNT_UUID); - assertEquals(0, keys.getCount(account, DEVICE_ID)); - assertEquals(0, keys.getCount(account, DEVICE_ID + 1)); + assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); } @Test public void testDeleteByAccountAndDevice() { - keys.store(account, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); - keys.store(account, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); + keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(new PreKey(1, "public-key"), new PreKey(2, "different-public-key"))); + keys.store(ACCOUNT_UUID, DEVICE_ID + 1, List.of(new PreKey(3, "public-key-for-different-device"))); - assertEquals(2, keys.getCount(account, DEVICE_ID)); - assertEquals(1, keys.getCount(account, DEVICE_ID + 1)); + assertEquals(2, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); - keys.delete(account.getUuid(), DEVICE_ID); + keys.delete(ACCOUNT_UUID, DEVICE_ID); - assertEquals(0, keys.getCount(account, DEVICE_ID)); - assertEquals(1, keys.getCount(account, DEVICE_ID + 1)); + assertEquals(0, keys.getCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keys.getCount(ACCOUNT_UUID, DEVICE_ID + 1)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index bf6d2fae6..63645d3ce 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -15,7 +15,6 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid; import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; @@ -26,7 +25,6 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -142,6 +140,7 @@ class KeysControllerTest { when(sampleDevice3.getId()).thenReturn(3L); when(sampleDevice4.getId()).thenReturn(4L); + when(existsAccount.getUuid()).thenReturn(EXISTS_UUID); when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice)); when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2)); when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3)); @@ -161,14 +160,9 @@ class KeysControllerTest { when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - when(KEYS.take(eq(existsAccount), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY)); + when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); - when(KEYS.take(existsAccount)).thenReturn(Map.of(1L, SAMPLE_KEY, - 2L, SAMPLE_KEY2, - 3L, SAMPLE_KEY3, - 4L, SAMPLE_KEY4)); - - when(KEYS.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5); + when(KEYS.getCount(AuthHelper.VALID_UUID, 1)).thenReturn(5); when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(VALID_DEVICE_SIGNED_KEY); when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); @@ -198,7 +192,7 @@ class KeysControllerTest { assertThat(result.getCount()).isEqualTo(4); - verify(KEYS).getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L)); + verify(KEYS).getCount(AuthHelper.VALID_UUID, 1); } @@ -257,7 +251,7 @@ class KeysControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(KEYS).take(eq(existsAccount), eq(1L)); + verify(KEYS).take(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -275,7 +269,7 @@ class KeysControllerTest { assertThat(result.getDevice(1).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getDevice(1).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey()); - verify(KEYS).take(eq(existsAccount), eq(1L)); + verify(KEYS).take(EXISTS_UUID, 1); verifyNoMoreInteractions(KEYS); } @@ -321,8 +315,13 @@ class KeysControllerTest { @Test void validMultiRequestTestV2() { + when(KEYS.take(EXISTS_UUID, 1)).thenReturn(Optional.of(SAMPLE_KEY)); + when(KEYS.take(EXISTS_UUID, 2)).thenReturn(Optional.of(SAMPLE_KEY2)); + when(KEYS.take(EXISTS_UUID, 3)).thenReturn(Optional.of(SAMPLE_KEY3)); + when(KEYS.take(EXISTS_UUID, 4)).thenReturn(Optional.of(SAMPLE_KEY4)); + PreKeyResponse results = resources.getJerseyTest() - .target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString())) + .target(String.format("/v2/keys/%s/*", EXISTS_UUID)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(PreKeyResponse.class); @@ -332,8 +331,8 @@ class KeysControllerTest { PreKey signedPreKey = results.getDevice(1).getSignedPreKey(); PreKey preKey = results.getDevice(1).getPreKey(); - long registrationId = results.getDevice(1).getRegistrationId(); - long deviceId = results.getDevice(1).getDeviceId(); + long registrationId = results.getDevice(1).getRegistrationId(); + long deviceId = results.getDevice(1).getDeviceId(); assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); @@ -365,7 +364,10 @@ class KeysControllerTest { assertThat(signedPreKey).isNull(); assertThat(deviceId).isEqualTo(4); - verify(KEYS).take(eq(existsAccount)); + verify(KEYS).take(EXISTS_UUID, 1); + verify(KEYS).take(EXISTS_UUID, 2); + verify(KEYS).take(EXISTS_UUID, 3); + verify(KEYS).take(EXISTS_UUID, 4); verifyNoMoreInteractions(KEYS); } @@ -433,8 +435,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture()); + ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); + verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -467,8 +469,8 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); - ArgumentCaptor listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture()); + ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); + verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(1L), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1);