From ade2e9c6cf36f291286c0a57d64b70c5367a6462 Mon Sep 17 00:00:00 2001 From: Katherine Yen Date: Wed, 19 Jul 2023 10:43:58 -0700 Subject: [PATCH] Define asynchronous `ProfilesManager` operations --- .../textsecuregcm/storage/Accounts.java | 14 +-- .../textsecuregcm/storage/Profiles.java | 25 ++++ .../storage/ProfilesManager.java | 89 +++++++++++--- .../textsecuregcm/util/AsyncTimerUtil.java | 15 +++ .../storage/ProfilesManagerTest.java | 111 +++++++++++++++++- .../textsecuregcm/storage/ProfilesTest.java | 22 +++- 6 files changed, 248 insertions(+), 28 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/AsyncTimerUtil.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 3dd61cf03..c6d7776e0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -28,11 +28,11 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; import java.util.function.Predicate; -import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -521,7 +521,7 @@ public class Accounts extends AbstractDynamoDbStore { @Nonnull public CompletionStage updateAsync(final Account account) { - return record(UPDATE_TIMER, () -> { + return AsyncTimerUtil.record(UPDATE_TIMER, () -> { final UpdateItemRequest updateItemRequest; try { // username, e164, and pni cannot be modified through this method @@ -676,7 +676,7 @@ public class Accounts extends AbstractDynamoDbStore { @Nonnull public CompletableFuture> getByAccountIdentifierAsync(final UUID uuid) { - return record(GET_BY_UUID_TIMER, () -> itemByKeyAsync(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)) + return AsyncTimerUtil.record(GET_BY_UUID_TIMER, () -> itemByKeyAsync(accountsTableName, KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid)) .thenApply(maybeItem -> maybeItem.map(Accounts::fromItem))) .toCompletableFuture(); } @@ -776,7 +776,7 @@ public class Accounts extends AbstractDynamoDbStore { final AttributeValue keyValue, final Predicate> predicate) { - return record(timer, () -> itemByKeyAsync(tableName, keyName, keyValue) + return AsyncTimerUtil.record(timer, () -> itemByKeyAsync(tableName, keyName, keyValue) .thenCompose(maybeItem -> maybeItem .filter(predicate) .map(item -> item.get(KEY_ACCOUNT_UUID)) @@ -934,12 +934,6 @@ public class Accounts extends AbstractDynamoDbStore { .build()) .build(); } - - @Nonnull - private static CompletionStage record(final Timer timer, final Supplier> toRecord) { - final Timer.Sample sample = Timer.start(); - return toRecord.get().whenComplete((ignoreT, ignoreE) -> sample.stop(timer)); - } @Nonnull private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java index 6e1d47cfe..970979458 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java @@ -19,7 +19,9 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; +import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.Util; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -95,6 +97,18 @@ public class Profiles { }); } + public CompletableFuture setAsync(final UUID uuid, final VersionedProfile profile) { + return AsyncTimerUtil.record(SET_PROFILES_TIMER, () -> dynamoDbAsyncClient.updateItem(UpdateItemRequest.builder() + .tableName(tableName) + .key(buildPrimaryKey(uuid, profile.getVersion())) + .updateExpression(buildUpdateExpression(profile)) + .expressionAttributeNames(UPDATE_EXPRESSION_ATTRIBUTE_NAMES) + .expressionAttributeValues(buildUpdateExpressionAttributeValues(profile)) + .build() + ).thenRun(Util.NOOP) + ).toCompletableFuture(); + } + private static Map buildPrimaryKey(final UUID uuid, final String version) { return Map.of( KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid), @@ -198,6 +212,17 @@ public class Profiles { }); } + public CompletableFuture> getAsync(final UUID uuid, final String version) { + return AsyncTimerUtil.record(GET_PROFILE_TIMER, () -> dynamoDbAsyncClient.getItem(GetItemRequest.builder() + .tableName(tableName) + .key(buildPrimaryKey(uuid, version)) + .consistentRead(true) + .build()) + .thenApply(response -> + response.hasItem() ? Optional.of(fromItem(response.item())) : Optional.empty()) + ).toCompletableFuture(); + } + private static VersionedProfile fromItem(final Map item) { return new VersionedProfile( AttributeValues.getString(item, ATTR_VERSION, null), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java index 915baf7a9..6b6e4cdde 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ProfilesManager.java @@ -11,10 +11,13 @@ import io.lettuce.core.RedisException; import java.io.IOException; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.Util; +import javax.annotation.Nullable; public class ProfilesManager { @@ -26,6 +29,7 @@ public class ProfilesManager { private final FaultTolerantRedisCluster cacheCluster; private final ObjectMapper mapper; + public ProfilesManager(final Profiles profiles, final FaultTolerantRedisCluster cacheCluster) { this.profiles = profiles; @@ -34,52 +38,105 @@ public class ProfilesManager { } public void set(UUID uuid, VersionedProfile versionedProfile) { - memcacheSet(uuid, versionedProfile); + redisSet(uuid, versionedProfile); profiles.set(uuid, versionedProfile); } + public CompletableFuture setAsync(UUID uuid, VersionedProfile versionedProfile) { + return profiles.setAsync(uuid, versionedProfile) + .thenCompose(ignored -> redisSetAsync(uuid, versionedProfile)); + } + public void deleteAll(UUID uuid) { - memcacheDelete(uuid); + redisDelete(uuid); profiles.deleteAll(uuid); } public Optional get(UUID uuid, String version) { - Optional profile = memcacheGet(uuid, version); + Optional profile = redisGet(uuid, version); if (profile.isEmpty()) { profile = profiles.get(uuid, version); - profile.ifPresent(versionedProfile -> memcacheSet(uuid, versionedProfile)); + profile.ifPresent(versionedProfile -> redisSet(uuid, versionedProfile)); } return profile; } - private void memcacheSet(UUID uuid, VersionedProfile profile) { + public CompletableFuture> getAsync(UUID uuid, String version) { + return redisGetAsync(uuid, version) + .thenCompose(maybeVersionedProfile -> maybeVersionedProfile + .map(versionedProfile -> CompletableFuture.completedFuture(maybeVersionedProfile)) + .orElseGet(() -> profiles.getAsync(uuid, version) + .thenCompose(maybeVersionedProfileFromDynamo -> maybeVersionedProfileFromDynamo + .map(profile -> redisSetAsync(uuid, profile).thenApply(ignored -> maybeVersionedProfileFromDynamo)) + .orElseGet(() -> CompletableFuture.completedFuture(maybeVersionedProfileFromDynamo))))); + } + + private void redisSet(UUID uuid, VersionedProfile profile) { try { final String profileJson = mapper.writeValueAsString(profile); - cacheCluster.useCluster(connection -> connection.sync().hset(CACHE_PREFIX + uuid.toString(), profile.getVersion(), profileJson)); + cacheCluster.useCluster(connection -> connection.sync().hset(getCacheKey(uuid), profile.getVersion(), profileJson)); } catch (JsonProcessingException e) { throw new IllegalArgumentException(e); } } - private Optional memcacheGet(UUID uuid, String version) { - try { - final String json = cacheCluster.withCluster(connection -> connection.sync().hget(CACHE_PREFIX + uuid.toString(), version)); + private CompletableFuture redisSetAsync(UUID uuid, VersionedProfile profile) { + final String profileJson; - if (json == null) return Optional.empty(); - else return Optional.of(mapper.readValue(json, VersionedProfile.class)); - } catch (IOException e) { - logger.warn("Error deserializing value...", e); - return Optional.empty(); + try { + profileJson = mapper.writeValueAsString(profile); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e); + } + + return cacheCluster.withCluster(connection -> + connection.async().hset(getCacheKey(uuid), profile.getVersion(), profileJson)) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } + + private Optional redisGet(UUID uuid, String version) { + try { + @Nullable final String json = cacheCluster.withCluster(connection -> connection.sync().hget(getCacheKey(uuid), version)); + + return parseProfileJson(json); } catch (RedisException e) { logger.warn("Redis exception", e); return Optional.empty(); } } - private void memcacheDelete(UUID uuid) { - cacheCluster.useCluster(connection -> connection.sync().del(CACHE_PREFIX + uuid.toString())); + private CompletableFuture> redisGetAsync(UUID uuid, String version) { + return cacheCluster.withCluster(connection -> + connection.async().hget(getCacheKey(uuid), version)) + .thenApply(this::parseProfileJson) + .exceptionally(throwable -> { + logger.warn("Failed to read versioned profile from Redis", throwable); + return Optional.empty(); + }) + .toCompletableFuture(); + } + + private Optional parseProfileJson(@Nullable final String maybeJson) { + try { + if (maybeJson != null) { + return Optional.of(mapper.readValue(maybeJson, VersionedProfile.class)); + } + return Optional.empty(); + } catch (final IOException e) { + logger.warn("Error deserializing value...", e); + return Optional.empty(); + } + } + + private void redisDelete(UUID uuid) { + cacheCluster.useCluster(connection -> connection.sync().del(getCacheKey(uuid))); + } + + private String getCacheKey(UUID uuid) { + return CACHE_PREFIX + uuid.toString(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/AsyncTimerUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/AsyncTimerUtil.java new file mode 100644 index 000000000..b7c87cda6 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/AsyncTimerUtil.java @@ -0,0 +1,15 @@ +package org.whispersystems.textsecuregcm.util; + +import io.micrometer.core.instrument.Timer; +import javax.annotation.Nonnull; +import java.util.concurrent.CompletionStage; +import java.util.function.Supplier; + +public class AsyncTimerUtil { + @Nonnull + public static CompletionStage record(final Timer timer, final Supplier> toRecord) { + final Timer.Sample sample = Timer.start(); + return toRecord.get().whenComplete((ignoreT, ignoreE) -> sample.stop(timer)); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesManagerTest.java index 51ef4f71e..628ddac50 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesManagerTest.java @@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -18,19 +19,25 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import io.lettuce.core.RedisException; +import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.util.Base64; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class ProfilesManagerTest { private Profiles profiles; private RedisAdvancedClusterCommands commands; + private RedisAdvancedClusterAsyncCommands asyncCommands; private ProfilesManager profilesManager; @@ -38,7 +45,11 @@ public class ProfilesManagerTest { void setUp() { //noinspection unchecked commands = mock(RedisAdvancedClusterCommands.class); - final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.builder().stringCommands(commands).build(); + asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); + final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.builder() + .stringCommands(commands) + .stringAsyncCommands(asyncCommands) + .build(); profiles = mock(Profiles.class); @@ -63,6 +74,25 @@ public class ProfilesManagerTest { verifyNoMoreInteractions(profiles); } + @Test + public void testGetProfileAsyncInCache() { + UUID uuid = UUID.randomUUID(); + + when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn( + MockRedisFuture.completedFuture("{\"version\": \"someversion\", \"name\": \"somename\", \"avatar\": \"someavatar\", \"commitment\":\"" + Base64.getEncoder().encodeToString("somecommitment".getBytes()) + "\"}")); + + Optional profile = profilesManager.getAsync(uuid, "someversion").join(); + + assertTrue(profile.isPresent()); + assertEquals(profile.get().getName(), "somename"); + assertEquals(profile.get().getAvatar(), "someavatar"); + assertThat(profile.get().getCommitment()).isEqualTo("somecommitment".getBytes()); + + verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); + verifyNoMoreInteractions(asyncCommands); + verifyNoMoreInteractions(profiles); + } + @Test public void testGetProfileNotInCache() { UUID uuid = UUID.randomUUID(); @@ -85,6 +115,29 @@ public class ProfilesManagerTest { verifyNoMoreInteractions(profiles); } + @Test + public void testGetProfileAsyncNotInCache() { + UUID uuid = UUID.randomUUID(); + VersionedProfile profile = new VersionedProfile("someversion", "somename", "someavatar", null, null, + null, "somecommitment".getBytes()); + + when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(MockRedisFuture.completedFuture(null)); + when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); + when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); + + Optional retrieved = profilesManager.getAsync(uuid, "someversion").join(); + + assertTrue(retrieved.isPresent()); + assertSame(retrieved.get(), profile); + + verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); + verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); + verifyNoMoreInteractions(asyncCommands); + + verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion")); + verifyNoMoreInteractions(profiles); + } + @Test public void testGetProfileBrokenCache() { UUID uuid = UUID.randomUUID(); @@ -106,4 +159,60 @@ public class ProfilesManagerTest { verify(profiles, times(1)).get(eq(uuid), eq("someversion")); verifyNoMoreInteractions(profiles); } + + @Test + public void testGetProfileAsyncBrokenCache() { + UUID uuid = UUID.randomUUID(); + VersionedProfile profile = new VersionedProfile("someversion", "somename", "someavatar", null, null, + null, "somecommitment".getBytes()); + + when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost"))); + when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); + when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); + + Optional retrieved = profilesManager.getAsync(uuid, "someversion").join(); + + assertTrue(retrieved.isPresent()); + assertSame(retrieved.get(), profile); + + verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); + verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); + verifyNoMoreInteractions(asyncCommands); + + verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion")); + verifyNoMoreInteractions(profiles); + } + + @Test + public void testSet() { + UUID uuid = UUID.randomUUID(); + VersionedProfile profile = new VersionedProfile("someversion", "somename", "someavatar", null, null, + null, "somecommitment".getBytes()); + + profilesManager.set(uuid, profile); + + verify(commands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), any()); + verifyNoMoreInteractions(commands); + + verify(profiles, times(1)).set(eq(uuid), eq(profile)); + verifyNoMoreInteractions(profiles); + } + + @Test + public void testSetAsync() { + UUID uuid = UUID.randomUUID(); + VersionedProfile profile = new VersionedProfile("someversion", "somename", "someavatar", null, null, + null, "somecommitment".getBytes()); + + when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); + when(profiles.setAsync(eq(uuid), eq(profile))).thenReturn(CompletableFuture.completedFuture(null)); + + profilesManager.setAsync(uuid, profile).join(); + + verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), any()); + verifyNoMoreInteractions(asyncCommands); + + verify(profiles, times(1)).setAsync(eq(uuid), eq(profile)); + verifyNoMoreInteractions(profiles); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java index e85b8661c..af37b1151 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -24,7 +25,8 @@ import java.util.Optional; import java.util.UUID; import java.util.stream.Stream; -public abstract class ProfilesTest { +@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +public class ProfilesTest { @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.PROFILES); @@ -56,6 +58,24 @@ public abstract class ProfilesTest { assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); } + @Test + void testSetGetAsync() { + UUID uuid = UUID.randomUUID(); + VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", "emoji", + "the very model of a modern major general", + null, "acommitment".getBytes()); + profiles.setAsync(uuid, profile).join(); + + Optional retrieved = profiles.getAsync(uuid, "123").join(); + + assertThat(retrieved.isPresent()).isTrue(); + assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); + assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); + assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); + assertThat(retrieved.get().getAbout()).isEqualTo(profile.getAbout()); + assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); + } + @Test void testDeleteReset() { UUID uuid = UUID.randomUUID();