When persisting messages fails due to a full queue in DynamoDB, automatically unlink one device to free up room.
Co-authored-by: Chris Eager <79161849+eager-signal@users.noreply.github.com>
This commit is contained in:
parent
ce60f13320
commit
8f7bae54fe
|
@ -16,17 +16,30 @@ import com.codahale.metrics.Timer;
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import io.dropwizard.lifecycle.Managed;
|
import io.dropwizard.lifecycle.Managed;
|
||||||
import io.micrometer.core.instrument.Counter;
|
import io.micrometer.core.instrument.Counter;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
import reactor.util.function.Tuple2;
|
||||||
|
import reactor.util.function.Tuples;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
|
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||||
import org.whispersystems.textsecuregcm.util.Constants;
|
import org.whispersystems.textsecuregcm.util.Constants;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
|
@ -35,6 +48,8 @@ public class MessagePersister implements Managed {
|
||||||
private final MessagesCache messagesCache;
|
private final MessagesCache messagesCache;
|
||||||
private final MessagesManager messagesManager;
|
private final MessagesManager messagesManager;
|
||||||
private final AccountsManager accountsManager;
|
private final AccountsManager accountsManager;
|
||||||
|
private final ClientPresenceManager clientPresenceManager;
|
||||||
|
private final KeysManager keysManager;
|
||||||
|
|
||||||
private final Duration persistDelay;
|
private final Duration persistDelay;
|
||||||
|
|
||||||
|
@ -50,27 +65,35 @@ public class MessagePersister implements Managed {
|
||||||
private final Counter oversizedQueueCounter = counter(name(MessagePersister.class, "persistQueueOversized"));
|
private final Counter oversizedQueueCounter = counter(name(MessagePersister.class, "persistQueueOversized"));
|
||||||
private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueCount"));
|
private final Histogram queueCountHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueCount"));
|
||||||
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueSize"));
|
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(MessagePersister.class, "queueSize"));
|
||||||
|
private final ExecutorService executor;
|
||||||
|
|
||||||
static final int QUEUE_BATCH_LIMIT = 100;
|
static final int QUEUE_BATCH_LIMIT = 100;
|
||||||
static final int MESSAGE_BATCH_LIMIT = 100;
|
static final int MESSAGE_BATCH_LIMIT = 100;
|
||||||
|
|
||||||
private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis();
|
private static final long EXCEPTION_PAUSE_MILLIS = Duration.ofSeconds(3).toMillis();
|
||||||
|
public static final Duration UNLINK_TIMEOUT = Duration.ofHours(1);
|
||||||
|
|
||||||
private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3;
|
private static final int CONSECUTIVE_EMPTY_CACHE_REMOVAL_LIMIT = 3;
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
|
private static final Logger logger = LoggerFactory.getLogger(MessagePersister.class);
|
||||||
|
|
||||||
public MessagePersister(final MessagesCache messagesCache, final MessagesManager messagesManager,
|
public MessagePersister(final MessagesCache messagesCache, final MessagesManager messagesManager,
|
||||||
final AccountsManager accountsManager,
|
final AccountsManager accountsManager, final ClientPresenceManager clientPresenceManager,
|
||||||
|
final KeysManager keysManager,
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final Duration persistDelay,
|
final Duration persistDelay,
|
||||||
final int dedicatedProcessWorkerThreadCount) {
|
final int dedicatedProcessWorkerThreadCount,
|
||||||
|
final ExecutorService executor
|
||||||
|
) {
|
||||||
this.messagesCache = messagesCache;
|
this.messagesCache = messagesCache;
|
||||||
this.messagesManager = messagesManager;
|
this.messagesManager = messagesManager;
|
||||||
this.accountsManager = accountsManager;
|
this.accountsManager = accountsManager;
|
||||||
|
this.clientPresenceManager = clientPresenceManager;
|
||||||
|
this.keysManager = keysManager;
|
||||||
this.persistDelay = persistDelay;
|
this.persistDelay = persistDelay;
|
||||||
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
|
this.workerThreads = new Thread[dedicatedProcessWorkerThreadCount];
|
||||||
this.dedicatedProcess = true;
|
this.dedicatedProcess = true;
|
||||||
|
this.executor = executor;
|
||||||
|
|
||||||
for (int i = 0; i < workerThreads.length; i++) {
|
for (int i = 0; i < workerThreads.length; i++) {
|
||||||
workerThreads[i] = new Thread(() -> {
|
workerThreads[i] = new Thread(() -> {
|
||||||
|
@ -139,12 +162,14 @@ public class MessagePersister implements Managed {
|
||||||
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue);
|
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue);
|
||||||
final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queue);
|
final byte deviceId = MessagesCache.getDeviceIdFromQueueName(queue);
|
||||||
|
|
||||||
|
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(accountUuid);
|
||||||
|
if (maybeAccount.isEmpty()) {
|
||||||
|
logger.error("No account record found for account {}", accountUuid);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
persistQueue(accountUuid, deviceId);
|
persistQueue(maybeAccount.get(), deviceId);
|
||||||
} catch (final Exception e) {
|
} catch (final Exception e) {
|
||||||
if (e instanceof ItemCollectionSizeLimitExceededException) {
|
|
||||||
oversizedQueueCounter.increment();
|
|
||||||
}
|
|
||||||
persistQueueExceptionMeter.mark();
|
persistQueueExceptionMeter.mark();
|
||||||
logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e);
|
logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId, e);
|
||||||
|
|
||||||
|
@ -161,14 +186,8 @@ public class MessagePersister implements Managed {
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
void persistQueue(final UUID accountUuid, final byte deviceId) throws MessagePersistenceException {
|
void persistQueue(final Account account, final byte deviceId) throws MessagePersistenceException {
|
||||||
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(accountUuid);
|
final UUID accountUuid = account.getUuid();
|
||||||
|
|
||||||
if (maybeAccount.isEmpty()) {
|
|
||||||
logger.error("No account record found for account {}", accountUuid);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try (final Timer.Context ignored = persistQueueTimer.time()) {
|
try (final Timer.Context ignored = persistQueueTimer.time()) {
|
||||||
messagesCache.lockQueueForPersistence(accountUuid, deviceId);
|
messagesCache.lockQueueForPersistence(accountUuid, deviceId);
|
||||||
|
|
||||||
|
@ -197,9 +216,73 @@ public class MessagePersister implements Managed {
|
||||||
} while (!messages.isEmpty());
|
} while (!messages.isEmpty());
|
||||||
|
|
||||||
queueSizeHistogram.update(messageCount);
|
queueSizeHistogram.update(messageCount);
|
||||||
|
} catch (ItemCollectionSizeLimitExceededException e) {
|
||||||
|
oversizedQueueCounter.increment();
|
||||||
|
unlinkLeastActiveDevice(account, deviceId); // this will either do a deferred reschedule for retry or throw
|
||||||
} finally {
|
} finally {
|
||||||
messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
|
messagesCache.unlockQueueForPersistence(accountUuid, deviceId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
void unlinkLeastActiveDevice(final Account account, byte destinationDeviceId) throws MessagePersistenceException {
|
||||||
|
if (!messagesCache.lockAccountForMessagePersisterCleanup(account.getUuid())) {
|
||||||
|
// don't try to unlink an account multiple times in parallel; just fail this persist attempt
|
||||||
|
// and we'll try again, hopefully deletions in progress will have made room
|
||||||
|
throw new MessagePersistenceException("account has a full queue and another device-unlinking attempt is in progress");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Selection logic:
|
||||||
|
|
||||||
|
// The primary device is never unlinked
|
||||||
|
// The least-recently-seen inactive device gets unlinked
|
||||||
|
// If there are none, the device with the oldest queued message (not necessarily the
|
||||||
|
// least-recently-seen; a device could be connecting frequently but have some problem fetching
|
||||||
|
// its messages) is unlinked
|
||||||
|
final Device deviceToDelete = account.getDevices()
|
||||||
|
.stream()
|
||||||
|
.filter(d -> !d.isPrimary() && !d.isEnabled())
|
||||||
|
.min(Comparator.comparing(Device::getLastSeen))
|
||||||
|
.or(() ->
|
||||||
|
Flux.fromIterable(account.getDevices())
|
||||||
|
.filter(d -> !d.isPrimary())
|
||||||
|
.flatMap(d ->
|
||||||
|
messagesManager
|
||||||
|
.getEarliestUndeliveredTimestampForDevice(account.getUuid(), d.getId())
|
||||||
|
.map(t -> Tuples.of(d, t)))
|
||||||
|
.sort(Comparator.comparing(Tuple2::getT2))
|
||||||
|
.map(Tuple2::getT1)
|
||||||
|
.next()
|
||||||
|
.blockOptional())
|
||||||
|
.orElseThrow(() -> new MessagePersistenceException("account has a full queue and no unlinkable devices"));
|
||||||
|
|
||||||
|
logger.warn("Failed to persist queue {}::{} due to overfull queue; will unlink device {}{}",
|
||||||
|
account.getUuid(), destinationDeviceId, deviceToDelete.getId(), deviceToDelete.getId() == destinationDeviceId ? "" : " and schedule for retry");
|
||||||
|
executor.execute(
|
||||||
|
() -> {
|
||||||
|
try {
|
||||||
|
// Lock the device's auth token first to prevent it from connecting while we're
|
||||||
|
// destroying its queue, but we don't want to completely remove the device from the
|
||||||
|
// account until we're sure the messages have been cleared because otherwise we won't
|
||||||
|
// be able to find it later to try again, in the event of a failure or timeout
|
||||||
|
final Account updatedAccount = accountsManager.update(
|
||||||
|
account,
|
||||||
|
a -> a.getDevice(deviceToDelete.getId()).ifPresent(Device::lockAuthTokenHash));
|
||||||
|
clientPresenceManager.disconnectPresence(account.getUuid(), deviceToDelete.getId());
|
||||||
|
CompletableFuture
|
||||||
|
.allOf(
|
||||||
|
keysManager.delete(account.getUuid(), deviceToDelete.getId()),
|
||||||
|
messagesManager.clear(account.getUuid(), deviceToDelete.getId()))
|
||||||
|
.orTimeout((UNLINK_TIMEOUT.toSeconds() * 3) / 4, TimeUnit.SECONDS)
|
||||||
|
.join();
|
||||||
|
accountsManager.update(updatedAccount, a -> a.removeDevice(deviceToDelete.getId()));
|
||||||
|
} finally {
|
||||||
|
messagesCache.unlockAccountForMessagePersisterCleanup(account.getUuid());
|
||||||
|
if (deviceToDelete.getId() != destinationDeviceId) { // no point in persisting a device we just purged
|
||||||
|
messagesCache.addQueueToPersist(account.getUuid(), destinationDeviceId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
import io.dropwizard.lifecycle.Managed;
|
import io.dropwizard.lifecycle.Managed;
|
||||||
import io.lettuce.core.ScoredValue;
|
import io.lettuce.core.ScoredValue;
|
||||||
import io.lettuce.core.ScriptOutputType;
|
import io.lettuce.core.ScriptOutputType;
|
||||||
|
import io.lettuce.core.SetArgs;
|
||||||
import io.lettuce.core.ZAddArgs;
|
import io.lettuce.core.ZAddArgs;
|
||||||
import io.lettuce.core.cluster.SlotHash;
|
import io.lettuce.core.cluster.SlotHash;
|
||||||
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
|
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
|
||||||
|
@ -382,6 +383,20 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
|
||||||
connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId)));
|
connection -> connection.sync().del(getPersistInProgressKey(accountUuid, deviceId)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean lockAccountForMessagePersisterCleanup(final UUID accountUuid) {
|
||||||
|
return readDeleteCluster.withBinaryCluster(
|
||||||
|
connection -> "OK".equals(
|
||||||
|
connection.sync().set(
|
||||||
|
getUnlinkInProgressKey(accountUuid),
|
||||||
|
LOCK_VALUE,
|
||||||
|
new SetArgs().ex(MessagePersister.UNLINK_TIMEOUT.toSeconds()).nx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
void unlockAccountForMessagePersisterCleanup(final UUID accountUuid) {
|
||||||
|
readDeleteCluster.useBinaryCluster(
|
||||||
|
connection -> connection.sync().del(getUnlinkInProgressKey(accountUuid)));
|
||||||
|
}
|
||||||
|
|
||||||
public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId,
|
public void addMessageAvailabilityListener(final UUID destinationUuid, final byte deviceId,
|
||||||
final MessageAvailabilityListener listener) {
|
final MessageAvailabilityListener listener) {
|
||||||
final String queueName = getQueueName(destinationUuid, deviceId);
|
final String queueName = getQueueName(destinationUuid, deviceId);
|
||||||
|
@ -531,6 +546,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
|
||||||
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
|
return ("user_queue_persisting::{" + accountUuid + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static byte[] getUnlinkInProgressKey(final UUID accountUuid) {
|
||||||
|
return ("user_account_unlinking::{" + accountUuid + "}").getBytes(StandardCharsets.UTF_8);
|
||||||
|
}
|
||||||
|
|
||||||
static UUID getAccountUuidFromQueueName(final String queueName) {
|
static UUID getAccountUuidFromQueueName(final String queueName) {
|
||||||
final int startOfHashTag = queueName.indexOf('{');
|
final int startOfHashTag = queueName.indexOf('{');
|
||||||
|
|
||||||
|
|
|
@ -102,6 +102,10 @@ public class MessagesManager {
|
||||||
.tap(Micrometer.metrics(Metrics.globalRegistry));
|
.tap(Micrometer.metrics(Metrics.globalRegistry));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Mono<Long> getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, byte destinationDevice) {
|
||||||
|
return Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1)).map(Envelope::getServerTimestamp);
|
||||||
|
}
|
||||||
|
|
||||||
public CompletableFuture<Void> clear(UUID destinationUuid) {
|
public CompletableFuture<Void> clear(UUID destinationUuid) {
|
||||||
return CompletableFuture.allOf(
|
return CompletableFuture.allOf(
|
||||||
messagesCache.clear(destinationUuid),
|
messagesCache.clear(destinationUuid),
|
||||||
|
|
|
@ -62,9 +62,14 @@ public class MessagePersisterServiceCommand extends ServerCommand<WhisperServerC
|
||||||
|
|
||||||
final MessagePersister messagePersister = new MessagePersister(deps.messagesCache(), deps.messagesManager(),
|
final MessagePersister messagePersister = new MessagePersister(deps.messagesCache(), deps.messagesManager(),
|
||||||
deps.accountsManager(),
|
deps.accountsManager(),
|
||||||
|
deps.clientPresenceManager(),
|
||||||
|
deps.keysManager(),
|
||||||
dynamicConfigurationManager,
|
dynamicConfigurationManager,
|
||||||
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
|
Duration.ofMinutes(configuration.getMessageCacheConfiguration().getPersistDelayMinutes()),
|
||||||
namespace.getInt(WORKER_COUNT));
|
namespace.getInt(WORKER_COUNT),
|
||||||
|
environment.lifecycle().executorService("messagePersisterUnlinkDeviceExecutor-%d")
|
||||||
|
.maxThreads(2)
|
||||||
|
.build());
|
||||||
|
|
||||||
environment.lifecycle().manage(deps.messagesCache());
|
environment.lifecycle().manage(deps.messagesCache());
|
||||||
environment.lifecycle().manage(messagePersister);
|
environment.lifecycle().manage(messagePersister);
|
||||||
|
|
|
@ -32,7 +32,9 @@ import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
|
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
|
||||||
import reactor.core.scheduler.Scheduler;
|
import reactor.core.scheduler.Scheduler;
|
||||||
import reactor.core.scheduler.Schedulers;
|
import reactor.core.scheduler.Schedulers;
|
||||||
|
@ -83,7 +85,8 @@ class MessagePersisterIntegrationTest {
|
||||||
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
|
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class),
|
||||||
messageDeletionExecutorService);
|
messageDeletionExecutorService);
|
||||||
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
|
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
|
||||||
dynamicConfigurationManager, PERSIST_DELAY, 1);
|
mock(ClientPresenceManager.class), mock(KeysManager.class), dynamicConfigurationManager, PERSIST_DELAY, 1,
|
||||||
|
Executors.newSingleThreadExecutor());
|
||||||
|
|
||||||
account = mock(Account.class);
|
account = mock(Account.class);
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
package org.whispersystems.textsecuregcm.storage;
|
package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
|
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
@ -16,9 +17,13 @@ import static org.mockito.Mockito.atLeastOnce;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.reset;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
|
||||||
|
|
||||||
|
import com.google.common.util.concurrent.MoreExecutors;
|
||||||
import com.google.protobuf.ByteString;
|
import com.google.protobuf.ByteString;
|
||||||
import io.lettuce.core.cluster.SlotHash;
|
import io.lettuce.core.cluster.SlotHash;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
@ -28,6 +33,7 @@ import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
@ -41,9 +47,14 @@ import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.stubbing.Answer;
|
import org.mockito.stubbing.Answer;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
import org.whispersystems.textsecuregcm.entities.MessageProtos;
|
||||||
|
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||||
|
|
||||||
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.core.scheduler.Scheduler;
|
import reactor.core.scheduler.Scheduler;
|
||||||
import reactor.core.scheduler.Schedulers;
|
import reactor.core.scheduler.Schedulers;
|
||||||
|
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
|
||||||
|
|
||||||
class MessagePersisterTest {
|
class MessagePersisterTest {
|
||||||
|
|
||||||
|
@ -57,7 +68,10 @@ class MessagePersisterTest {
|
||||||
private MessagesDynamoDb messagesDynamoDb;
|
private MessagesDynamoDb messagesDynamoDb;
|
||||||
private MessagePersister messagePersister;
|
private MessagePersister messagePersister;
|
||||||
private AccountsManager accountsManager;
|
private AccountsManager accountsManager;
|
||||||
|
private ClientPresenceManager clientPresenceManager;
|
||||||
|
private KeysManager keysManager;
|
||||||
private MessagesManager messagesManager;
|
private MessagesManager messagesManager;
|
||||||
|
private Account destinationAccount;
|
||||||
|
|
||||||
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
|
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
|
||||||
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
|
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
|
||||||
|
@ -74,11 +88,13 @@ class MessagePersisterTest {
|
||||||
|
|
||||||
messagesDynamoDb = mock(MessagesDynamoDb.class);
|
messagesDynamoDb = mock(MessagesDynamoDb.class);
|
||||||
accountsManager = mock(AccountsManager.class);
|
accountsManager = mock(AccountsManager.class);
|
||||||
|
clientPresenceManager = mock(ClientPresenceManager.class);
|
||||||
|
keysManager = mock(KeysManager.class);
|
||||||
|
destinationAccount = mock(Account.class);;
|
||||||
|
|
||||||
final Account account = mock(Account.class);
|
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount));
|
||||||
|
when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID);
|
||||||
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
|
when(destinationAccount.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
|
||||||
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
|
|
||||||
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
|
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
|
||||||
|
|
||||||
sharedExecutorService = Executors.newSingleThreadExecutor();
|
sharedExecutorService = Executors.newSingleThreadExecutor();
|
||||||
|
@ -87,8 +103,8 @@ class MessagePersisterTest {
|
||||||
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, messageDeliveryScheduler,
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(), sharedExecutorService, messageDeliveryScheduler,
|
||||||
sharedExecutorService, Clock.systemUTC());
|
sharedExecutorService, Clock.systemUTC());
|
||||||
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
|
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, clientPresenceManager,
|
||||||
dynamicConfigurationManager, PERSIST_DELAY, 1);
|
keysManager, dynamicConfigurationManager, PERSIST_DELAY, 1, MoreExecutors.newDirectExecutorService());
|
||||||
|
|
||||||
when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> {
|
when(messagesManager.persistMessages(any(UUID.class), anyByte(), any())).thenAnswer(invocation -> {
|
||||||
final UUID destinationUuid = invocation.getArgument(0);
|
final UUID destinationUuid = invocation.getArgument(0);
|
||||||
|
@ -172,6 +188,7 @@ class MessagePersisterTest {
|
||||||
final Account account = mock(Account.class);
|
final Account account = mock(Account.class);
|
||||||
|
|
||||||
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
|
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
|
||||||
|
when(account.getUuid()).thenReturn(accountUuid);
|
||||||
when(account.getNumber()).thenReturn(accountNumber);
|
when(account.getNumber()).thenReturn(accountNumber);
|
||||||
|
|
||||||
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
|
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
|
||||||
|
@ -223,7 +240,150 @@ class MessagePersisterTest {
|
||||||
|
|
||||||
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
|
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
|
||||||
assertThrows(MessagePersistenceException.class,
|
assertThrows(MessagePersistenceException.class,
|
||||||
() -> messagePersister.persistQueue(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID)));
|
() -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testUnlinkFirstInactiveDeviceOnFullQueue() {
|
||||||
|
final String queueName = new String(
|
||||||
|
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||||
|
final int messageCount = 1;
|
||||||
|
final Instant now = Instant.now();
|
||||||
|
|
||||||
|
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
|
||||||
|
setNextSlotToPersist(SlotHash.getSlot(queueName));
|
||||||
|
|
||||||
|
final Device primary = mock(Device.class);
|
||||||
|
when(primary.getId()).thenReturn((byte) 1);
|
||||||
|
when(primary.isPrimary()).thenReturn(true);
|
||||||
|
when(primary.isEnabled()).thenReturn(true);
|
||||||
|
final Device activeA = mock(Device.class);
|
||||||
|
when(activeA.getId()).thenReturn((byte) 2);
|
||||||
|
when(activeA.isEnabled()).thenReturn(true);
|
||||||
|
final Device inactiveB = mock(Device.class);
|
||||||
|
final byte inactiveId = 3;
|
||||||
|
when(inactiveB.getId()).thenReturn(inactiveId);
|
||||||
|
when(inactiveB.isEnabled()).thenReturn(false);
|
||||||
|
final Device inactiveC = mock(Device.class);
|
||||||
|
when(inactiveC.getId()).thenReturn((byte) 4);
|
||||||
|
when(inactiveC.isEnabled()).thenReturn(false);
|
||||||
|
final Device activeD = mock(Device.class);
|
||||||
|
when(activeD.getId()).thenReturn((byte) 5);
|
||||||
|
when(activeD.isEnabled()).thenReturn(true);
|
||||||
|
final Device destination = mock(Device.class);
|
||||||
|
when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID);
|
||||||
|
when(destination.isEnabled()).thenReturn(true);
|
||||||
|
|
||||||
|
when(destinationAccount.getDevices()).thenReturn(List.of(primary, activeA, inactiveB, inactiveC, activeD, destination));
|
||||||
|
|
||||||
|
when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
|
||||||
|
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
when(keysManager.delete(any(), eq(inactiveId))).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
|
||||||
|
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
|
||||||
|
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID));
|
||||||
|
|
||||||
|
verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, inactiveId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testUnlinkActiveDeviceWithOldestMessageOnFullQueueWithNoInactiveDevices() {
|
||||||
|
final String queueName = new String(
|
||||||
|
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||||
|
final int messageCount = 1;
|
||||||
|
final Instant now = Instant.now();
|
||||||
|
|
||||||
|
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
|
||||||
|
setNextSlotToPersist(SlotHash.getSlot(queueName));
|
||||||
|
|
||||||
|
final Device primary = mock(Device.class);
|
||||||
|
final byte primaryId = 1;
|
||||||
|
when(primary.getId()).thenReturn(primaryId);
|
||||||
|
when(primary.isPrimary()).thenReturn(true);
|
||||||
|
when(primary.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId)))
|
||||||
|
.thenReturn(Mono.just(4L));
|
||||||
|
|
||||||
|
final Device deviceA = mock(Device.class);
|
||||||
|
final byte deviceIdA = 2;
|
||||||
|
when(deviceA.getId()).thenReturn(deviceIdA);
|
||||||
|
when(deviceA.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA)))
|
||||||
|
.thenReturn(Mono.empty());
|
||||||
|
|
||||||
|
final Device deviceB = mock(Device.class);
|
||||||
|
final byte deviceIdB = 3;
|
||||||
|
when(deviceB.getId()).thenReturn(deviceIdB);
|
||||||
|
when(deviceB.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB)))
|
||||||
|
.thenReturn(Mono.just(2L));
|
||||||
|
|
||||||
|
final Device destination = mock(Device.class);
|
||||||
|
when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID);
|
||||||
|
when(destination.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID)))
|
||||||
|
.thenReturn(Mono.just(5L));
|
||||||
|
|
||||||
|
when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination));
|
||||||
|
|
||||||
|
when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
|
||||||
|
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
when(keysManager.delete(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
|
||||||
|
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
|
||||||
|
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID));
|
||||||
|
|
||||||
|
verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, deviceIdB);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testUnlinkDestinationDevice() {
|
||||||
|
final String queueName = new String(
|
||||||
|
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
|
||||||
|
final int messageCount = 1;
|
||||||
|
final Instant now = Instant.now();
|
||||||
|
|
||||||
|
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
|
||||||
|
setNextSlotToPersist(SlotHash.getSlot(queueName));
|
||||||
|
|
||||||
|
final Device primary = mock(Device.class);
|
||||||
|
final byte primaryId = 1;
|
||||||
|
when(primary.getId()).thenReturn(primaryId);
|
||||||
|
when(primary.isPrimary()).thenReturn(true);
|
||||||
|
when(primary.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(primaryId)))
|
||||||
|
.thenReturn(Mono.just(1L));
|
||||||
|
|
||||||
|
final Device deviceA = mock(Device.class);
|
||||||
|
final byte deviceIdA = 2;
|
||||||
|
when(deviceA.getId()).thenReturn(deviceIdA);
|
||||||
|
when(deviceA.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdA)))
|
||||||
|
.thenReturn(Mono.just(3L));
|
||||||
|
|
||||||
|
final Device deviceB = mock(Device.class);
|
||||||
|
final byte deviceIdB = 2;
|
||||||
|
when(deviceB.getId()).thenReturn(deviceIdB);
|
||||||
|
when(deviceB.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(deviceIdB)))
|
||||||
|
.thenReturn(Mono.empty());
|
||||||
|
|
||||||
|
final Device destination = mock(Device.class);
|
||||||
|
when(destination.getId()).thenReturn(DESTINATION_DEVICE_ID);
|
||||||
|
when(destination.isEnabled()).thenReturn(true);
|
||||||
|
when(messagesManager.getEarliestUndeliveredTimestampForDevice(any(), eq(DESTINATION_DEVICE_ID)))
|
||||||
|
.thenReturn(Mono.just(2L));
|
||||||
|
|
||||||
|
when(destinationAccount.getDevices()).thenReturn(List.of(primary, deviceA, deviceB, destination));
|
||||||
|
|
||||||
|
when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build());
|
||||||
|
when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
|
||||||
|
assertTimeoutPreemptively(Duration.ofSeconds(1), () ->
|
||||||
|
messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID));
|
||||||
|
|
||||||
|
verify(messagesManager, exactly()).clear(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("SameParameterValue")
|
@SuppressWarnings("SameParameterValue")
|
||||||
|
@ -265,5 +425,4 @@ class MessagePersisterTest {
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
|
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
|
||||||
connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
|
connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,12 +9,23 @@ import static org.mockito.ArgumentMatchers.eq;
|
||||||
import static org.mockito.Mockito.doNothing;
|
import static org.mockito.Mockito.doNothing;
|
||||||
import static org.mockito.Mockito.doReturn;
|
import static org.mockito.Mockito.doReturn;
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.internal.exceptions.Reporter.noMoreInteractionsWanted;
|
||||||
|
import static org.mockito.internal.exceptions.Reporter.wantedButNotInvoked;
|
||||||
|
import static org.mockito.internal.invocation.InvocationMarker.markVerified;
|
||||||
|
import static org.mockito.internal.invocation.InvocationsFinder.findFirstUnverified;
|
||||||
|
import static org.mockito.internal.invocation.InvocationsFinder.findInvocations;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.mockito.Mockito;
|
import org.mockito.Mockito;
|
||||||
|
import org.mockito.invocation.Invocation;
|
||||||
|
import org.mockito.invocation.MatchableInvocation;
|
||||||
|
import org.mockito.verification.VerificationMode;
|
||||||
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
|
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
|
@ -154,4 +165,30 @@ public final class MockUtils {
|
||||||
}
|
}
|
||||||
return new SecretBytes(bytes);
|
return new SecretBytes(bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* modeled after {@link org.mockito.Mockito#only()}, verifies that the matched invocation is the only invocation of
|
||||||
|
* this method
|
||||||
|
*/
|
||||||
|
public static VerificationMode exactly() {
|
||||||
|
return data -> {
|
||||||
|
MatchableInvocation target = data.getTarget();
|
||||||
|
final List<Invocation> allInvocations = data.getAllInvocations();
|
||||||
|
List<Invocation> chunk = findInvocations(allInvocations, target);
|
||||||
|
List<Invocation> otherInvocations = allInvocations.stream()
|
||||||
|
.filter(target::hasSameMethod)
|
||||||
|
.filter(Predicate.not(target::matches))
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
if (!otherInvocations.isEmpty()) {
|
||||||
|
Invocation unverified = findFirstUnverified(otherInvocations);
|
||||||
|
throw noMoreInteractionsWanted(unverified, (List) allInvocations);
|
||||||
|
}
|
||||||
|
if (chunk.isEmpty()) {
|
||||||
|
throw wantedButNotInvoked(target);
|
||||||
|
}
|
||||||
|
markVerified(chunk.get(0), target);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue