Clarify behavioral contract of the pre-key store

This commit is contained in:
Jon Chambers 2021-02-08 11:45:57 -05:00 committed by GitHub
parent 04728ea4bc
commit aa99e202b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 132 deletions

View File

@ -21,7 +21,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import javax.validation.Valid;
@ -35,8 +34,10 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -120,28 +121,22 @@ public class KeysController {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getNumber() + "." + deviceId);
}
List<KeyRecord> targetKeys = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> devices = new LinkedList<>();
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> responseItems = new LinkedList<>();
for (Device device : target.get().getDevices()) {
if (device.isEnabled() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey();
PreKey preKey = null;
for (KeyRecord keyRecord : targetKeys) {
if (keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKey(keyRecord.getKeyId(), keyRecord.getPublicKey());
}
}
PreKey preKey = preKeysByDeviceId.get(device.getId());
if (signedPreKey != null || preKey != null) {
devices.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey));
responseItems.add(new PreKeyResponseItem(device.getId(), device.getRegistrationId(), signedPreKey, preKey));
}
}
}
if (devices.isEmpty()) return Optional.empty();
else return Optional.of(new PreKeyResponse(target.get().getIdentityKey(), devices));
if (responseItems.isEmpty()) return Optional.empty();
else return Optional.of(new PreKeyResponse(target.get().getIdentityKey(), responseItems));
}
@Timed
@ -172,7 +167,7 @@ public class KeysController {
else return Optional.empty();
}
private List<KeyRecord> getLocalKeys(Account destination, String deviceIdSelector) {
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector) {
try {
if (deviceIdSelector.equals("*")) {
return keysDynamoDb.take(destination);
@ -180,7 +175,9 @@ public class KeysController {
long deviceId = Long.parseLong(deviceIdSelector);
return keysDynamoDb.take(destination, deviceId);
return keysDynamoDb.take(destination, deviceId)
.map(preKey -> Map.of(deviceId, preKey))
.orElse(Collections.emptyMap());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}

View File

@ -1,62 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.Objects;
public class KeyRecord {
private long id;
private String number;
private long deviceId;
private long keyId;
private String publicKey;
public KeyRecord(long id, String number, long deviceId, long keyId, String publicKey) {
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.keyId = keyId;
this.publicKey = publicKey;
}
public long getId() {
return id;
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public long getKeyId() {
return keyId;
}
public String getPublicKey() {
return publicKey;
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final KeyRecord keyRecord = (KeyRecord)o;
return id == keyRecord.id &&
deviceId == keyRecord.deviceId &&
keyId == keyRecord.keyId &&
Objects.equals(number, keyRecord.number) &&
Objects.equals(publicKey, keyRecord.publicKey);
}
@Override
public int hashCode() {
return Objects.hash(id, number, deviceId, keyId, publicKey);
}
}

View File

@ -25,8 +25,10 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
@ -69,7 +71,7 @@ public class KeysDynamoDb extends AbstractDynamoDbStore {
});
}
public List<KeyRecord> take(final Account account, final long deviceId) {
public Optional<PreKey> take(final Account account, final long deviceId) {
return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> {
final byte[] partitionKey = getPartitionKey(account.getUuid());
@ -90,29 +92,28 @@ public class KeysDynamoDb extends AbstractDynamoDbStore {
final DeleteItemOutcome outcome = table.deleteItem(deleteItemSpec);
if (outcome.getItem() != null) {
final PreKey preKey = getPreKeyFromItem(outcome.getItem());
return List.of(new KeyRecord(-1, account.getNumber(), deviceId, preKey.getKeyId(), preKey.getPublicKey()));
return Optional.of(getPreKeyFromItem(outcome.getItem()));
}
contestedKeys++;
}
return Collections.emptyList();
return Optional.empty();
} finally {
CONTESTED_KEY_DISTRIBUTION.record(contestedKeys);
}
});
}
public List<KeyRecord> take(final Account account) {
public Map<Long, PreKey> take(final Account account) {
return TAKE_KEYS_FOR_ACCOUNT_TIMER.record(() -> {
final List<KeyRecord> keyRecords = new ArrayList<>();
final Map<Long, PreKey> preKeysByDeviceId = new HashMap<>();
for (final Device device : account.getDevices()) {
keyRecords.addAll(take(account, device.getId()));
take(account, device.getId()).ifPresent(preKey -> preKeysByDeviceId.put(device.getId(), preKey));
}
return keyRecords;
return preKeysByDeviceId;
});
}

View File

@ -1,25 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import java.sql.ResultSet;
import java.sql.SQLException;
public class KeyRecordRowMapper implements RowMapper<KeyRecord> {
@Override
public KeyRecord map(ResultSet resultSet, StatementContext ctx) throws SQLException {
return new KeyRecord(resultSet.getLong("id"),
resultSet.getString("number"),
resultSet.getLong("device_id"),
resultSet.getLong("key_id"),
resultSet.getString("public_key"));
}
}

View File

@ -11,8 +11,9 @@ import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.PreKey;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
@ -70,7 +71,7 @@ public class KeysDynamoDbTest {
when(secondDevice.getId()).thenReturn(DEVICE_ID + 1);
when(account.getDevices()).thenReturn(Set.of(firstDevice, secondDevice));
assertEquals(Collections.emptyList(), keysDynamoDb.take(account));
assertEquals(Collections.emptyMap(), keysDynamoDb.take(account));
final PreKey firstDevicePreKey = new PreKey(1, "public-key");
final PreKey secondDevicePreKey = new PreKey(2, "second-key");
@ -78,23 +79,22 @@ public class KeysDynamoDbTest {
keysDynamoDb.store(account, DEVICE_ID, List.of(firstDevicePreKey));
keysDynamoDb.store(account, DEVICE_ID + 1, List.of(secondDevicePreKey));
final Set<KeyRecord> expectedKeys = Set.of(
new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, firstDevicePreKey.getKeyId(), firstDevicePreKey.getPublicKey()),
new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID + 1, secondDevicePreKey.getKeyId(), secondDevicePreKey.getPublicKey()));
final Map<Long, PreKey> expectedKeys = Map.of(DEVICE_ID, firstDevicePreKey,
DEVICE_ID + 1, secondDevicePreKey);
assertEquals(expectedKeys, new HashSet<>(keysDynamoDb.take(account)));
assertEquals(expectedKeys, keysDynamoDb.take(account));
assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID));
assertEquals(0, keysDynamoDb.getCount(account, DEVICE_ID + 1));
}
@Test
public void testTakeAccountAndDeviceId() {
assertEquals(Collections.emptyList(), keysDynamoDb.take(account, DEVICE_ID));
assertEquals(Optional.empty(), keysDynamoDb.take(account, DEVICE_ID));
final PreKey preKey = new PreKey(1, "public-key");
keysDynamoDb.store(account, DEVICE_ID, List.of(preKey, new PreKey(2, "different-pre-key")));
assertEquals(List.of(new KeyRecord(-1, ACCOUNT_NUMBER, DEVICE_ID, preKey.getKeyId(), preKey.getPublicKey())), keysDynamoDb.take(account, DEVICE_ID));
assertEquals(Optional.of(preKey), keysDynamoDb.take(account, DEVICE_ID));
assertEquals(1, keysDynamoDb.getCount(account, DEVICE_ID));
}

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -38,6 +37,7 @@ import javax.ws.rs.core.Response;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
@ -59,10 +59,10 @@ public class KeysControllerTest {
private static int SAMPLE_REGISTRATION_ID2 = 1002;
private static int SAMPLE_REGISTRATION_ID4 = 1555;
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1");
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3");
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5");
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6");
private final PreKey SAMPLE_KEY = new PreKey(1234, "test1");
private final PreKey SAMPLE_KEY2 = new PreKey(5667, "test3");
private final PreKey SAMPLE_KEY3 = new PreKey(334, "test5");
private final PreKey SAMPLE_KEY4 = new PreKey(336, "test6");
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey( 1111, "foofoo", "sig11" );
@ -140,16 +140,12 @@ public class KeysControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
List<KeyRecord> singleDevice = new LinkedList<>();
singleDevice.add(SAMPLE_KEY);
when(keysDynamoDb.take(eq(existsAccount), eq(1L))).thenReturn(singleDevice);
when(keysDynamoDb.take(eq(existsAccount), eq(1L))).thenReturn(Optional.of(SAMPLE_KEY));
List<KeyRecord> multiDevice = new LinkedList<>();
multiDevice.add(SAMPLE_KEY);
multiDevice.add(SAMPLE_KEY2);
multiDevice.add(SAMPLE_KEY3);
multiDevice.add(SAMPLE_KEY4);
when(keysDynamoDb.take(existsAccount)).thenReturn(multiDevice);
when(keysDynamoDb.take(existsAccount)).thenReturn(Map.of(1L, SAMPLE_KEY,
2L, SAMPLE_KEY2,
3L, SAMPLE_KEY3,
4L, SAMPLE_KEY4));
when(keysDynamoDb.getCount(eq(AuthHelper.VALID_ACCOUNT), eq(1L))).thenReturn(5);