diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommand.java new file mode 100644 index 000000000..0377aa025 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommand.java @@ -0,0 +1,119 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; +import com.codahale.metrics.jdbi3.strategies.DefaultNameStrategy; +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.InvalidProtocolBufferException; +import io.dropwizard.Application; +import io.dropwizard.cli.EnvironmentCommand; +import io.dropwizard.jdbi3.JdbiFactory; +import io.dropwizard.setup.Environment; +import io.lettuce.core.ScanArgs; +import io.lettuce.core.ScanIterator; +import net.sourceforge.argparse4j.inf.Namespace; +import org.jdbi.v3.core.Jdbi; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; +import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; +import org.whispersystems.textsecuregcm.storage.Messages; +import org.whispersystems.textsecuregcm.util.Constants; + +import java.nio.charset.StandardCharsets; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.codahale.metrics.MetricRegistry.name; + +public class ScourMessageCacheCommand extends EnvironmentCommand { + + private FaultTolerantRedisClient redisClient; + private Messages messageDatabase; + + private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private final Histogram queueSizeHistogram = metricRegistry.histogram(name(getClass(), "queueSize")); + + private static final Logger log = LoggerFactory.getLogger(ScourMessageCacheCommand.class); + + public ScourMessageCacheCommand() { + super(new Application<>() { + @Override + public void run(final WhisperServerConfiguration whisperServerConfiguration, final Environment environment) { + } + }, "scourmessagecache", "Persist and remove all message queues from the old message cache"); + } + + @Override + protected void run(final Environment environment, final Namespace namespace, final WhisperServerConfiguration config) { + final JdbiFactory jdbiFactory = new JdbiFactory(DefaultNameStrategy.CHECK_EMPTY); + final Jdbi messageJdbi = jdbiFactory.build(environment, config.getMessageStoreConfiguration(), "messagedb" ); + final FaultTolerantDatabase messageDatabase = new FaultTolerantDatabase("message_database", messageJdbi, config.getMessageStoreConfiguration().getCircuitBreakerConfiguration()); + + this.setMessageDatabase(new Messages(messageDatabase)); + this.setRedisClient(new FaultTolerantRedisClient("scourMessageCacheClient", config.getMessageCacheConfiguration().getRedisConfiguration())); + + scourMessageCache(); + } + + @VisibleForTesting + void setRedisClient(final FaultTolerantRedisClient redisClient) { + this.redisClient = redisClient; + } + + @VisibleForTesting + void setMessageDatabase(final Messages messageDatabase) { + this.messageDatabase = messageDatabase; + } + + @VisibleForTesting + void scourMessageCache() { + redisClient.useClient(connection -> ScanIterator.scan(connection.sync(), ScanArgs.Builder.matches("user_queue::*")) + .stream() + .forEach(this::persistQueue)); + } + + @VisibleForTesting + void persistQueue(final String queueKey) { + final String accountNumber; + { + final int startOfAccountNumber = queueKey.indexOf("::"); + accountNumber = queueKey.substring(startOfAccountNumber + 2, queueKey.indexOf("::", startOfAccountNumber + 1)); + } + + final long deviceId = Long.parseLong(queueKey.substring(queueKey.lastIndexOf("::") + 2)); + + final AtomicInteger messageCount = new AtomicInteger(0); + + redisClient.useBinaryClient(connection -> connection.sync().zrange(messageBytes -> { + persistMessage(accountNumber, deviceId, messageBytes); + messageCount.incrementAndGet(); + }, queueKey.getBytes(StandardCharsets.UTF_8), 0, Long.MAX_VALUE)); + + redisClient.useClient(connection -> { + final String accountNumberAndDeviceId = accountNumber + "::" + deviceId; + + connection.async().del("user_queue::" + accountNumberAndDeviceId, + "user_queue_metadata::" + accountNumberAndDeviceId, + "user_queue_persisting::" + accountNumberAndDeviceId); + }); + + queueSizeHistogram.update(messageCount.longValue()); + } + + private void persistMessage(final String accountNumber, final long deviceId, final byte[] message) { + try { + MessageProtos.Envelope envelope = MessageProtos.Envelope.parseFrom(message); + UUID guid = envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null; + + envelope = envelope.toBuilder().clearServerGuid().build(); + + messageDatabase.store(guid, envelope, accountNumber, deviceId); + } catch (InvalidProtocolBufferException e) { + log.error("Error parsing envelope", e); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommandTest.java new file mode 100644 index 000000000..3ce285b1a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/ScourMessageCacheCommandTest.java @@ -0,0 +1,84 @@ +package org.whispersystems.textsecuregcm.workers; + +import com.google.protobuf.ByteString; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Before; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.AbstractRedisSingletonTest; +import org.whispersystems.textsecuregcm.storage.Messages; +import org.whispersystems.textsecuregcm.storage.MessagesCache; + +import java.util.Random; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ScourMessageCacheCommandTest extends AbstractRedisSingletonTest { + + private Messages messageDatabase; + private MessagesCache messagesCache; + private ScourMessageCacheCommand scourMessageCacheCommand; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + + messageDatabase = mock(Messages.class); + messagesCache = new MessagesCache(getJedisPool()); + scourMessageCacheCommand = new ScourMessageCacheCommand(); + + scourMessageCacheCommand.setMessageDatabase(messageDatabase); + scourMessageCacheCommand.setRedisClient(getRedisClient()); + } + + @Test + public void testScourMessageCache() { + final int messageCount = insertDetachedMessages(100, 1_000); + + scourMessageCacheCommand.scourMessageCache(); + + verify(messageDatabase, times(messageCount)).store(any(UUID.class), any(MessageProtos.Envelope.class), anyString(), anyLong()); + assertEquals(0, (long)getRedisClient().withClient(connection -> connection.sync().dbsize())); + } + + @SuppressWarnings("SameParameterValue") + private int insertDetachedMessages(final int accounts, final int maxMessagesPerAccount) { + int totalMessages = 0; + + final Random random = new Random(); + + for (int i = 0; i < accounts; i++) { + final String accountNumber = String.format("+1800%07d", i); + final UUID accountUuid = UUID.randomUUID(); + final int messageCount = random.nextInt(maxMessagesPerAccount); + + for (int j = 0; j < messageCount; j++) { + final UUID messageGuid = UUID.randomUUID(); + + final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() + .setTimestamp(System.currentTimeMillis()) + .setServerTimestamp(System.currentTimeMillis()) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(messageGuid.toString()) + .build(); + + messagesCache.insert(messageGuid, accountNumber, accountUuid, 1, envelope); + } + + totalMessages += messageCount; + } + + getRedisClient().useClient(connection -> connection.sync().del("user_queue_index")); + + return totalMessages; + } +}