Delete avatars in ProfilesManager#deleteAll

This commit is contained in:
Chris Eager 2025-05-22 12:35:07 -05:00 committed by Chris Eager
parent 8491d18413
commit c1a66e0418
7 changed files with 132 additions and 53 deletions

View File

@ -394,6 +394,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final DynamoDbClient dynamoDbClient = config.getDynamoDbClientConfiguration() final DynamoDbClient dynamoDbClient = config.getDynamoDbClientConfiguration()
.buildSyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbSync")); .buildSyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbSync"));
final AwsCredentialsProvider cdnCredentialsProvider = config.getCdnConfiguration().credentials().build();
final S3AsyncClient asyncCdnS3Client = S3AsyncClient.builder()
.credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().region()))
.build();
BlockingQueue<Runnable> messageDeletionQueue = new LinkedBlockingQueue<>(); BlockingQueue<Runnable> messageDeletionQueue = new LinkedBlockingQueue<>();
Metrics.gaugeCollectionSize(name(getClass(), "messageDeletionQueueSize"), Collections.emptyList(), Metrics.gaugeCollectionSize(name(getClass(), "messageDeletionQueueSize"), Collections.emptyList(),
messageDeletionQueue); messageDeletionQueue);
@ -500,8 +506,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.scheduledExecutorService(name(getClass(), "remoteStorageRetry-%d")).threads(1).build(); .scheduledExecutorService(name(getClass(), "remoteStorageRetry-%d")).threads(1).build();
ScheduledExecutorService registrationIdentityTokenRefreshExecutor = environment.lifecycle() ScheduledExecutorService registrationIdentityTokenRefreshExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "registrationIdentityTokenRefresh-%d")).threads(1).build(); .scheduledExecutorService(name(getClass(), "registrationIdentityTokenRefresh-%d")).threads(1).build();
ScheduledExecutorService recurringConfigSyncExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "configSync-%d")).threads(1).build();
Scheduler messageDeliveryScheduler = Schedulers.fromExecutorService( Scheduler messageDeliveryScheduler = Schedulers.fromExecutorService(
ExecutorServiceMetrics.monitor(Metrics.globalRegistry, ExecutorServiceMetrics.monitor(Metrics.globalRegistry,
@ -610,7 +614,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,
storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration()); storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration());
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, asyncCdnS3Client, config.getCdnConfiguration().bucket());
MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler,
messageDeletionAsyncExecutor, clock); messageDeletionAsyncExecutor, clock);
ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases, ClientReleaseManager clientReleaseManager = new ClientReleaseManager(clientReleases,
@ -745,17 +749,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(virtualThreadPinEventMonitor); environment.lifecycle().manage(virtualThreadPinEventMonitor);
environment.lifecycle().manage(accountsManager); environment.lifecycle().manage(accountsManager);
final S3Client cdnS3Client = S3Client.builder()
AwsCredentialsProvider cdnCredentialsProvider = config.getCdnConfiguration().credentials().build();
S3Client cdnS3Client = S3Client.builder()
.credentialsProvider(cdnCredentialsProvider) .credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().region())) .region(Region.of(config.getCdnConfiguration().region()))
.httpClientBuilder(AwsCrtHttpClient.builder()) .httpClientBuilder(AwsCrtHttpClient.builder())
.build(); .build();
S3AsyncClient asyncCdnS3Client = S3AsyncClient.builder()
.credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().region()))
.build();
final GcsAttachmentGenerator gcsAttachmentGenerator = new GcsAttachmentGenerator( final GcsAttachmentGenerator gcsAttachmentGenerator = new GcsAttachmentGenerator(
config.getGcpAttachmentsConfiguration().domain(), config.getGcpAttachmentsConfiguration().domain(),

View File

@ -261,7 +261,12 @@ public class Profiles {
return AttributeValues.extractByteArray(attributeValue, PARSE_BYTE_ARRAY_COUNTER_NAME); return AttributeValues.extractByteArray(attributeValue, PARSE_BYTE_ARRAY_COUNTER_NAME);
} }
public CompletableFuture<Void> deleteAll(final UUID uuid) { /**
* Deletes all profile versions for the given UUID
*
* @return a list of avatar URLs to be deleted
*/
public CompletableFuture<List<String>> deleteAll(final UUID uuid) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
final AttributeValue uuidAttributeValue = AttributeValues.fromUUID(uuid); final AttributeValue uuidAttributeValue = AttributeValues.fromUUID(uuid);
@ -271,7 +276,7 @@ public class Profiles {
.keyConditionExpression("#uuid = :uuid") .keyConditionExpression("#uuid = :uuid")
.expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID))
.expressionAttributeValues(Map.of(":uuid", uuidAttributeValue)) .expressionAttributeValues(Map.of(":uuid", uuidAttributeValue))
.projectionExpression(ATTR_VERSION) .projectionExpression(String.join(", ", ATTR_VERSION, ATTR_AVATAR))
.consistentRead(true) .consistentRead(true)
.build()) .build())
.items()) .items())
@ -280,8 +285,9 @@ public class Profiles {
.key(Map.of( .key(Map.of(
KEY_ACCOUNT_UUID, uuidAttributeValue, KEY_ACCOUNT_UUID, uuidAttributeValue,
ATTR_VERSION, item.get(ATTR_VERSION))) ATTR_VERSION, item.get(ATTR_VERSION)))
.build())), MAX_CONCURRENCY) .build()))
.then() .flatMap(ignored -> Mono.justOrEmpty(item.get(ATTR_AVATAR)).map(AttributeValue::s)), MAX_CONCURRENCY)
.collectList()
.doOnSuccess(ignored -> sample.stop(Metrics.timer(DELETE_PROFILES_TIMER_NAME, "outcome", "success"))) .doOnSuccess(ignored -> sample.stop(Metrics.timer(DELETE_PROFILES_TIMER_NAME, "outcome", "success")))
.doOnError(ignored -> sample.stop(Metrics.timer(DELETE_PROFILES_TIMER_NAME, "outcome", "error"))) .doOnError(ignored -> sample.stop(Metrics.timer(DELETE_PROFILES_TIMER_NAME, "outcome", "error")))
.toFuture(); .toFuture();

View File

@ -7,17 +7,22 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import java.io.IOException; import java.io.IOException;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable; import reactor.core.publisher.Mono;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
public class ProfilesManager { public class ProfilesManager {
@ -27,13 +32,19 @@ public class ProfilesManager {
private final Profiles profiles; private final Profiles profiles;
private final FaultTolerantRedisClusterClient cacheCluster; private final FaultTolerantRedisClusterClient cacheCluster;
private final S3AsyncClient s3Client;
private final String bucket;
private final ObjectMapper mapper; private final ObjectMapper mapper;
private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];
public ProfilesManager(final Profiles profiles,
final FaultTolerantRedisClusterClient cacheCluster) { public ProfilesManager(final Profiles profiles, final FaultTolerantRedisClusterClient cacheCluster, final S3AsyncClient s3Client,
final String bucket) {
this.profiles = profiles; this.profiles = profiles;
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.s3Client = s3Client;
this.bucket = bucket;
this.mapper = SystemMapper.jsonMapper(); this.mapper = SystemMapper.jsonMapper();
} }
@ -48,7 +59,21 @@ public class ProfilesManager {
} }
public CompletableFuture<Void> deleteAll(UUID uuid) { public CompletableFuture<Void> deleteAll(UUID uuid) {
return CompletableFuture.allOf(redisDelete(uuid), profiles.deleteAll(uuid));
final CompletableFuture<Void> profilesAndAvatars = Mono.fromFuture(profiles.deleteAll(uuid))
.flatMapIterable(Function.identity())
.flatMap(avatar ->
Mono.fromFuture(s3Client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(avatar)
.build()))
// this is best-effort
.retry(3)
.onErrorComplete()
.then()
).then().toFuture();
return CompletableFuture.allOf(redisDelete(uuid), profilesAndAvatars);
} }
public Optional<VersionedProfile> get(UUID uuid, String version) { public Optional<VersionedProfile> get(UUID uuid, String version) {
@ -137,7 +162,8 @@ public class ProfilesManager {
.thenRun(Util.NOOP); .thenRun(Util.NOOP);
} }
private String getCacheKey(UUID uuid) { @VisibleForTesting
static String getCacheKey(UUID uuid) {
return CACHE_PREFIX + uuid.toString(); return CACHE_PREFIX + uuid.toString();
} }
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter;
@ -15,6 +16,7 @@ public record VersionedProfile (String version,
@JsonDeserialize(using = ByteArrayBase64WithPaddingAdapter.Deserializing.class) @JsonDeserialize(using = ByteArrayBase64WithPaddingAdapter.Deserializing.class)
byte[] name, byte[] name,
@Nullable
String avatar, String avatar,
@JsonSerialize(using = ByteArrayBase64WithPaddingAdapter.Serializing.class) @JsonSerialize(using = ByteArrayBase64WithPaddingAdapter.Serializing.class)

View File

@ -68,8 +68,10 @@ import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.s3.S3AsyncClient;
/** /**
* Construct utilities commonly used by worker commands * Construct utilities commonly used by worker commands
@ -175,6 +177,13 @@ record CommandDependencies(
DynamoDbClient dynamoDbClient = configuration.getDynamoDbClientConfiguration() DynamoDbClient dynamoDbClient = configuration.getDynamoDbClientConfiguration()
.buildSyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbSyncCommand")); .buildSyncClient(awsCredentialsProvider, new MicrometerAwsSdkMetricPublisher(awsSdkMetricsExecutor, "dynamoDbSyncCommand"));
final AwsCredentialsProvider cdnCredentialsProvider = configuration.getCdnConfiguration().credentials().build();
final S3AsyncClient asyncCdnS3Client = S3AsyncClient.builder()
.credentialsProvider(cdnCredentialsProvider)
.region(Region.of(configuration.getCdnConfiguration().region()))
.build();
RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords( RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords(
configuration.getDynamoDbTables().getRegistrationRecovery().getTableName(), configuration.getDynamoDbTables().getRegistrationRecovery().getTableName(),
configuration.getDynamoDbTables().getRegistrationRecovery().getExpiration(), configuration.getDynamoDbTables().getRegistrationRecovery().getExpiration(),
@ -222,7 +231,8 @@ record CommandDependencies(
DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, MessagesCache messagesCache = new MessagesCache(messagesCluster,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC()); messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC());
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster, asyncCdnS3Client,
configuration.getCdnConfiguration().bucket());
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient, ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(dynamoDbClient, dynamoDbAsyncClient,
configuration.getDynamoDbTables().getReportMessage().getTableName(), configuration.getDynamoDbTables().getReportMessage().getTableName(),
configuration.getReportMessageConfiguration().getReportTtl()); configuration.getReportMessageConfiguration().getReportTtl());

View File

@ -21,6 +21,7 @@ import static org.mockito.Mockito.when;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -35,6 +36,8 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class ProfilesManagerTest { public class ProfilesManagerTest {
@ -42,9 +45,12 @@ public class ProfilesManagerTest {
private Profiles profiles; private Profiles profiles;
private RedisAdvancedClusterCommands<String, String> commands; private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands; private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private S3AsyncClient s3Client;
private ProfilesManager profilesManager; private ProfilesManager profilesManager;
private static final String BUCKET = "bucket";
@BeforeEach @BeforeEach
void setUp() { void setUp() {
//noinspection unchecked //noinspection unchecked
@ -56,8 +62,9 @@ public class ProfilesManagerTest {
.build(); .build();
profiles = mock(Profiles.class); profiles = mock(Profiles.class);
s3Client = mock(S3AsyncClient.class);
profilesManager = new ProfilesManager(profiles, cacheCluster); profilesManager = new ProfilesManager(profiles, cacheCluster, s3Client, BUCKET);
} }
@Test @Test
@ -65,7 +72,7 @@ public class ProfilesManagerTest {
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
final byte[] name = TestRandomUtil.nextBytes(81); final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(uuid)).serialize(); final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(uuid)).serialize();
when(commands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(String.format( when(commands.hget(eq(ProfilesManager.getCacheKey( uuid)), eq("someversion"))).thenReturn(String.format(
"{\"version\": \"someversion\", \"name\": \"%s\", \"avatar\": \"someavatar\", \"commitment\":\"%s\"}", "{\"version\": \"someversion\", \"name\": \"%s\", \"avatar\": \"someavatar\", \"commitment\":\"%s\"}",
ProfileTestHelper.encodeToBase64(name), ProfileTestHelper.encodeToBase64(name),
ProfileTestHelper.encodeToBase64(commitment))); ProfileTestHelper.encodeToBase64(commitment)));
@ -74,10 +81,10 @@ public class ProfilesManagerTest {
assertTrue(profile.isPresent()); assertTrue(profile.isPresent());
assertArrayEquals(profile.get().name(), name); assertArrayEquals(profile.get().name(), name);
assertEquals(profile.get().avatar(), "someavatar"); assertEquals("someavatar", profile.get().avatar());
assertArrayEquals(profile.get().commitment(), commitment); assertArrayEquals(profile.get().commitment(), commitment);
verify(commands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(commands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(commands);
verifyNoMoreInteractions(profiles); verifyNoMoreInteractions(profiles);
} }
@ -88,7 +95,7 @@ public class ProfilesManagerTest {
final byte[] name = TestRandomUtil.nextBytes(81); final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(uuid)).serialize(); final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(uuid)).serialize();
when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn( when(asyncCommands.hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"))).thenReturn(
MockRedisFuture.completedFuture(String.format("{\"version\": \"someversion\", \"name\": \"%s\", \"avatar\": \"someavatar\", \"commitment\":\"%s\"}", MockRedisFuture.completedFuture(String.format("{\"version\": \"someversion\", \"name\": \"%s\", \"avatar\": \"someavatar\", \"commitment\":\"%s\"}",
ProfileTestHelper.encodeToBase64(name), ProfileTestHelper.encodeToBase64(name),
ProfileTestHelper.encodeToBase64(commitment)))); ProfileTestHelper.encodeToBase64(commitment))));
@ -97,10 +104,10 @@ public class ProfilesManagerTest {
assertTrue(profile.isPresent()); assertTrue(profile.isPresent());
assertArrayEquals(profile.get().name(), name); assertArrayEquals(profile.get().name(), name);
assertEquals(profile.get().avatar(), "someavatar"); assertEquals("someavatar", profile.get().avatar());
assertArrayEquals(profile.get().commitment(), commitment); assertArrayEquals(profile.get().commitment(), commitment);
verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(asyncCommands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncCommands);
verifyNoMoreInteractions(profiles); verifyNoMoreInteractions(profiles);
} }
@ -112,7 +119,7 @@ public class ProfilesManagerTest {
final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null, final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null,
null, null, "somecommitment".getBytes()); null, null, "somecommitment".getBytes());
when(commands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(null); when(commands.hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"))).thenReturn(null);
when(profiles.get(eq(uuid), eq("someversion"))).thenReturn(Optional.of(profile)); when(profiles.get(eq(uuid), eq("someversion"))).thenReturn(Optional.of(profile));
Optional<VersionedProfile> retrieved = profilesManager.get(uuid, "someversion"); Optional<VersionedProfile> retrieved = profilesManager.get(uuid, "someversion");
@ -120,8 +127,8 @@ public class ProfilesManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), profile); assertSame(retrieved.get(), profile);
verify(commands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(commands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verify(commands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); verify(commands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(commands);
verify(profiles, times(1)).get(eq(uuid), eq("someversion")); verify(profiles, times(1)).get(eq(uuid), eq("someversion"));
@ -135,8 +142,8 @@ public class ProfilesManagerTest {
final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null, final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null,
null, null, "somecommitment".getBytes()); null, null, "somecommitment".getBytes());
when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null));
when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
Optional<VersionedProfile> retrieved = profilesManager.getAsync(uuid, "someversion").join(); Optional<VersionedProfile> retrieved = profilesManager.getAsync(uuid, "someversion").join();
@ -144,8 +151,8 @@ public class ProfilesManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), profile); assertSame(retrieved.get(), profile);
verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(asyncCommands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); verify(asyncCommands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncCommands);
verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion")); verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion"));
@ -159,7 +166,7 @@ public class ProfilesManagerTest {
final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null, final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null,
null, null, "somecommitment".getBytes()); null, null, "somecommitment".getBytes());
when(commands.hget(eq("profiles::" + uuid), eq("someversion"))).thenThrow(new RedisException("Connection lost")); when(commands.hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"))).thenThrow(new RedisException("Connection lost"));
when(profiles.get(eq(uuid), eq("someversion"))).thenReturn(Optional.of(profile)); when(profiles.get(eq(uuid), eq("someversion"))).thenReturn(Optional.of(profile));
Optional<VersionedProfile> retrieved = profilesManager.get(uuid, "someversion"); Optional<VersionedProfile> retrieved = profilesManager.get(uuid, "someversion");
@ -167,8 +174,8 @@ public class ProfilesManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), profile); assertSame(retrieved.get(), profile);
verify(commands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(commands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verify(commands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); verify(commands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(commands);
verify(profiles, times(1)).get(eq(uuid), eq("someversion")); verify(profiles, times(1)).get(eq(uuid), eq("someversion"));
@ -182,8 +189,8 @@ public class ProfilesManagerTest {
final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null, final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null,
null, null, "somecommitment".getBytes()); null, null, "somecommitment".getBytes());
when(asyncCommands.hget(eq("profiles::" + uuid), eq("someversion"))).thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost"))); when(asyncCommands.hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"))).thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost")));
when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null));
when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile))); when(profiles.getAsync(eq(uuid), eq("someversion"))).thenReturn(CompletableFuture.completedFuture(Optional.of(profile)));
Optional<VersionedProfile> retrieved = profilesManager.getAsync(uuid, "someversion").join(); Optional<VersionedProfile> retrieved = profilesManager.getAsync(uuid, "someversion").join();
@ -191,8 +198,8 @@ public class ProfilesManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), profile); assertSame(retrieved.get(), profile);
verify(asyncCommands, times(1)).hget(eq("profiles::" + uuid), eq("someversion")); verify(asyncCommands, times(1)).hget(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"));
verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), anyString()); verify(asyncCommands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncCommands);
verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion")); verify(profiles, times(1)).getAsync(eq(uuid), eq("someversion"));
@ -208,7 +215,7 @@ public class ProfilesManagerTest {
profilesManager.set(uuid, profile); profilesManager.set(uuid, profile);
verify(commands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), any()); verify(commands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), any());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(commands);
verify(profiles, times(1)).set(eq(uuid), eq(profile)); verify(profiles, times(1)).set(eq(uuid), eq(profile));
@ -222,15 +229,39 @@ public class ProfilesManagerTest {
final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null, final VersionedProfile profile = new VersionedProfile("someversion", name, "someavatar", null, null,
null, null, "somecommitment".getBytes()); null, null, "somecommitment".getBytes());
when(asyncCommands.hset(eq("profiles::" + uuid), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), anyString())).thenReturn(MockRedisFuture.completedFuture(null));
when(profiles.setAsync(eq(uuid), eq(profile))).thenReturn(CompletableFuture.completedFuture(null)); when(profiles.setAsync(eq(uuid), eq(profile))).thenReturn(CompletableFuture.completedFuture(null));
profilesManager.setAsync(uuid, profile).join(); profilesManager.setAsync(uuid, profile).join();
verify(asyncCommands, times(1)).hset(eq("profiles::" + uuid), eq("someversion"), any()); verify(asyncCommands, times(1)).hset(eq(ProfilesManager.getCacheKey(uuid)), eq("someversion"), any());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncCommands);
verify(profiles, times(1)).setAsync(eq(uuid), eq(profile)); verify(profiles, times(1)).setAsync(eq(uuid), eq(profile));
verifyNoMoreInteractions(profiles); verifyNoMoreInteractions(profiles);
} }
@Test
public void testDeleteAll() {
final UUID uuid = UUID.randomUUID();
final String avatarOne = "avatar1";
final String avatarTwo = "avatar2";
when(profiles.deleteAll(uuid)).thenReturn(CompletableFuture.completedFuture(List.of(avatarOne, avatarTwo)));
when(asyncCommands.del(ProfilesManager.getCacheKey(uuid))).thenReturn(MockRedisFuture.completedFuture(null));
when(s3Client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
profilesManager.deleteAll(uuid).join();
verify(profiles).deleteAll(uuid);
verify(asyncCommands).del(ProfilesManager.getCacheKey(uuid));
verify(s3Client).deleteObject(DeleteObjectRequest.builder()
.bucket(BUCKET)
.key(avatarOne)
.build());
verify(s3Client).deleteObject(DeleteObjectRequest.builder()
.bucket(BUCKET)
.key(avatarTwo)
.build());
}
} }

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -226,6 +227,7 @@ public class ProfilesTest {
void testDelete() throws InvalidInputException { void testDelete() throws InvalidInputException {
final String versionOne = "versionOne"; final String versionOne = "versionOne";
final String versionTwo = "versionTwo"; final String versionTwo = "versionTwo";
final String versionThree = "versionThree";
final byte[] nameOne = TestRandomUtil.nextBytes(81); final byte[] nameOne = TestRandomUtil.nextBytes(81);
final byte[] nameTwo = TestRandomUtil.nextBytes(81); final byte[] nameTwo = TestRandomUtil.nextBytes(81);
@ -238,23 +240,27 @@ public class ProfilesTest {
final byte[] commitmentOne = new ProfileKey(TestRandomUtil.nextBytes(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); final byte[] commitmentOne = new ProfileKey(TestRandomUtil.nextBytes(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
final byte[] commitmentTwo = new ProfileKey(TestRandomUtil.nextBytes(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); final byte[] commitmentTwo = new ProfileKey(TestRandomUtil.nextBytes(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
final byte[] commitmentThree = new ProfileKey(TestRandomUtil.nextBytes(32)).getCommitment(new ServiceId.Aci(ACI)).serialize();
VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null, VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null,
null, null, commitmentOne); null, null, commitmentOne);
VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, null, commitmentTwo); VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, null, commitmentTwo);
VersionedProfile profileThree = new VersionedProfile(versionThree, nameTwo, null, aboutEmoji, about, null, null,
commitmentThree);
profiles.set(ACI, profileOne); profiles.set(ACI, profileOne);
profiles.set(ACI, profileTwo); profiles.set(ACI, profileTwo);
profiles.set(ACI, profileThree);
profiles.deleteAll(ACI).join(); final List<String> avatars = profiles.deleteAll(ACI).join();
Optional<VersionedProfile> retrieved = profiles.get(ACI, versionOne); for (String version : List.of(versionOne, versionTwo, versionThree)) {
final Optional<VersionedProfile> retrieved = profiles.get(ACI, version);
assertThat(retrieved.isPresent()).isFalse();
}
assertThat(retrieved.isPresent()).isFalse(); assertThat(avatars.size()).isEqualTo(2);
assertThat(avatars.containsAll(List.of(avatarOne, avatarTwo))).isTrue();
retrieved = profiles.get(ACI, versionTwo);
assertThat(retrieved.isPresent()).isFalse();
} }
@ParameterizedTest @ParameterizedTest
@ -266,7 +272,7 @@ public class ProfilesTest {
private static Stream<Arguments> buildUpdateExpression() throws InvalidInputException { private static Stream<Arguments> buildUpdateExpression() throws InvalidInputException {
final String version = "someVersion"; final String version = "someVersion";
final byte[] name = TestRandomUtil.nextBytes(81); final byte[] name = TestRandomUtil.nextBytes(81);
final String avatar = "profiles/" + ProfileTestHelper.generateRandomBase64FromByteArray(16);; final String avatar = "profiles/" + ProfileTestHelper.generateRandomBase64FromByteArray(16);
final byte[] emoji = TestRandomUtil.nextBytes(60); final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156); final byte[] about = TestRandomUtil.nextBytes(156);
final byte[] paymentAddress = TestRandomUtil.nextBytes(582); final byte[] paymentAddress = TestRandomUtil.nextBytes(582);
@ -313,7 +319,7 @@ public class ProfilesTest {
private static Stream<Arguments> buildUpdateExpressionAttributeValues() throws InvalidInputException { private static Stream<Arguments> buildUpdateExpressionAttributeValues() throws InvalidInputException {
final String version = "someVersion"; final String version = "someVersion";
final byte[] name = TestRandomUtil.nextBytes(81); final byte[] name = TestRandomUtil.nextBytes(81);
final String avatar = "profiles/" + ProfileTestHelper.generateRandomBase64FromByteArray(16);; final String avatar = "profiles/" + ProfileTestHelper.generateRandomBase64FromByteArray(16);
final byte[] emoji = TestRandomUtil.nextBytes(60); final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156); final byte[] about = TestRandomUtil.nextBytes(156);
final byte[] paymentAddress = TestRandomUtil.nextBytes(582); final byte[] paymentAddress = TestRandomUtil.nextBytes(582);