diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index ba034fad8..f15c564b2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -94,7 +94,15 @@ public class MessagePersister implements Managed { } for (final String queue : queuesToPersist) { - persistQueue(queue); + final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue); + final long deviceId = MessagesCache.getDeviceIdFromQueueName(queue); + + try { + persistQueue(accountUuid, deviceId); + } catch (final Exception e) { + logger.warn("Failed to persist queue {}::{}; will schedule for retry", accountUuid, deviceId); + messagesCache.addQueueToPersist(accountUuid, deviceId); + } } queuesPersisted += queuesToPersist.size(); @@ -104,10 +112,7 @@ public class MessagePersister implements Managed { } @VisibleForTesting - void persistQueue(final String queue) { - final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queue); - final long deviceId = MessagesCache.getDeviceIdFromQueueName(queue); - + void persistQueue(final UUID accountUuid, final long deviceId) { final Optional maybeAccount = accountsManager.get(accountUuid); final String accountNumber; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 47068953f..de01ca470 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -5,6 +5,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.lifecycle.Managed; import io.lettuce.core.ScoredValue; import io.lettuce.core.ScriptOutputType; +import io.lettuce.core.ZAddArgs; import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.models.partitions.RedisClusterNode; @@ -322,6 +323,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp String.valueOf(limit))); } + void addQueueToPersist(final UUID accountUuid, final long deviceId) { + redisCluster.useBinaryCluster(connection -> connection.sync().zadd(getQueueIndexKey(accountUuid, deviceId), ZAddArgs.Builder.nx(), System.currentTimeMillis(), getMessageQueueKey(accountUuid, deviceId))); + } + void lockQueueForPersistence(final UUID accountUuid, final long deviceId) { redisCluster.useBinaryCluster(connection -> connection.sync().setex(getPersistInProgressKey(accountUuid, deviceId), 30, LOCK_VALUE)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 9ba8a7384..13185b8bd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -6,6 +6,7 @@ import org.apache.commons.lang3.RandomStringUtils; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; @@ -165,6 +166,25 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } + @Test + public void testPersistQueueRetry() { + final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + doAnswer((Answer)invocation -> { + throw new RuntimeException("OH NO."); + }).when(messagesDatabase).store(any(), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + + messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); + + assertEquals(List.of(queueName), + messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), Instant.now().plus(messagePersister.getPersistDelay()), 1)); + } + @SuppressWarnings("SameParameterValue") private static String generateRandomQueueNameForSlot(final int slot) { final UUID uuid = UUID.randomUUID();