From be8a1acca948e2dd515294bb96490c1d8690862e Mon Sep 17 00:00:00 2001 From: Ehren Kret Date: Thu, 11 Feb 2021 10:50:03 -0600 Subject: [PATCH] Remove message database from the codebase (#395) * Remove message database from the codebase * Remove unused ExperimentEnrollmentManager in test * Be more stylish --- .../WhisperServerConfiguration.java | 9 - .../textsecuregcm/WhisperServerService.java | 50 ++- .../controllers/AccountController.java | 45 ++- .../controllers/DeviceController.java | 4 +- .../controllers/MessageController.java | 61 ++-- .../storage/AccountsManager.java | 15 +- .../storage/MessagePersister.java | 2 +- .../textsecuregcm/storage/Messages.java | 183 ---------- .../storage/MessagesManager.java | 85 ++--- .../OutgoingMessageEntityRowMapper.java | 44 --- .../websocket/WebSocketConnection.java | 45 ++- .../workers/DeleteUserCommand.java | 16 +- .../textsecuregcm/workers/VacuumCommand.java | 10 - service/src/main/resources/messagedb.xml | 139 -------- .../MessagePersisterIntegrationTest.java | 45 +-- .../storage/MessagePersisterTest.java | 63 ++-- .../controllers/DeviceControllerTest.java | 35 +- .../controllers/MessageControllerTest.java | 12 +- .../tests/storage/MessagesTest.java | 317 ------------------ .../WebSocketConnectionIntegrationTest.java | 89 ++--- .../websocket/WebSocketConnectionTest.java | 38 +-- 21 files changed, 258 insertions(+), 1049 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java delete mode 100644 service/src/main/resources/messagedb.xml delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 1b6742656..5594822d9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -134,11 +134,6 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private DynamoDbConfiguration keysDynamoDb; - @Valid - @NotNull - @JsonProperty - private DatabaseConfiguration messageStore; - @Valid @NotNull @JsonProperty @@ -316,10 +311,6 @@ public class WhisperServerConfiguration extends Configuration { return keysDynamoDb; } - public DatabaseConfiguration getMessageStoreConfiguration() { - return messageStore; - } - public DatabaseConfiguration getAbuseDatabaseConfiguration() { return abuseDatabase; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 61b98c25e..eb345ad13 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -4,6 +4,8 @@ */ package org.whispersystems.textsecuregcm; +import static com.codahale.metrics.MetricRegistry.name; + import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; @@ -38,6 +40,21 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.distribution.DistributionStatisticConfig; import io.micrometer.wavefront.WavefrontConfig; import io.micrometer.wavefront.WavefrontMeterRegistry; +import java.security.Security; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletRegistration; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.eclipse.jetty.servlets.CrossOriginFilter; import org.jdbi.v3.core.Jdbi; @@ -126,7 +143,6 @@ import org.whispersystems.textsecuregcm.storage.FeatureFlags; import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagePersister; -import org.whispersystems.textsecuregcm.storage.Messages; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -152,35 +168,17 @@ import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.DeleteFeatureFlagTask; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; -import org.whispersystems.textsecuregcm.workers.SetRequestLoggingEnabledTask; import org.whispersystems.textsecuregcm.workers.GetRedisCommandStatsCommand; import org.whispersystems.textsecuregcm.workers.GetRedisSlowlogCommand; import org.whispersystems.textsecuregcm.workers.ListFeatureFlagsTask; import org.whispersystems.textsecuregcm.workers.SetCrawlerAccelerationTask; import org.whispersystems.textsecuregcm.workers.SetFeatureFlagTask; +import org.whispersystems.textsecuregcm.workers.SetRequestLoggingEnabledTask; import org.whispersystems.textsecuregcm.workers.VacuumCommand; import org.whispersystems.textsecuregcm.workers.ZkParamsCommand; import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.setup.WebSocketEnvironment; -import javax.servlet.DispatcherType; -import javax.servlet.FilterRegistration; -import javax.servlet.ServletRegistration; -import java.security.Security; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; - -import static com.codahale.metrics.MetricRegistry.name; - public class WhisperServerService extends Application { static { @@ -204,13 +202,6 @@ public class WhisperServerService extends Application("messagedb", "messagedb.xml") { - @Override - public DataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) { - return configuration.getMessageStoreConfiguration(); - } - }); - bootstrap.addBundle(new NameableMigrationsBundle("abusedb", "abusedb.xml") { @Override public PooledDataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) { @@ -261,11 +252,9 @@ public class WhisperServerService extends Application messagesManager.clear(definitelyExistingAccount.getUuid())); pendingAccounts.remove(number); return account; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 8e7cf4d4d..98b4314c9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -102,7 +102,7 @@ public class DeviceController { account.removeDevice(deviceId); accounts.update(account); directoryQueue.refreshRegisteredUser(account); - messages.clear(account.getNumber(), account.getUuid(), deviceId); + messages.clear(account.getUuid(), deviceId); } @Timed @@ -196,7 +196,7 @@ public class DeviceController { device.setCapabilities(accountAttributes.getCapabilities()); account.get().addDevice(device); - messages.clear(account.get().getNumber(), account.get().getUuid(), device.getId()); + messages.clear(account.get().getUuid(), device.getId()); accounts.update(account.get()); pendingDevices.remove(number); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 4c26a4b80..c841edf81 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -4,6 +4,8 @@ */ package org.whispersystems.textsecuregcm.controllers; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; @@ -15,6 +17,25 @@ import io.dropwizard.auth.Auth; import io.dropwizard.util.DataSize; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; +import java.io.IOException; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import javax.validation.Valid; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.HeaderParam; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; @@ -46,28 +67,6 @@ import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; -import javax.validation.Valid; -import javax.ws.rs.Consumes; -import javax.ws.rs.DELETE; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import java.io.IOException; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; - -import static com.codahale.metrics.MetricRegistry.name; - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/messages") public class MessageController { @@ -226,11 +225,11 @@ public class MessageController { RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get())); } - final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(account.getNumber(), - account.getUuid(), - account.getAuthenticatedDevice().get().getId(), - userAgent, - false); + final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice( + account.getUuid(), + account.getAuthenticatedDevice().get().getId(), + userAgent, + false); outgoingMessageListSizeHistogram.update(outgoingMessages.getMessages().size()); @@ -271,8 +270,8 @@ public class MessageController { { try { WebSocketConnection.recordMessageDeliveryDuration(timestamp, account.getAuthenticatedDevice().get()); - Optional message = messagesManager.delete(account.getNumber(), - account.getUuid(), + Optional message = messagesManager.delete( + account.getUuid(), account.getAuthenticatedDevice().get().getId(), source, timestamp); @@ -291,8 +290,8 @@ public class MessageController { @Path("/uuid/{uuid}") public void removePendingMessage(@Auth Account account, @PathParam("uuid") UUID uuid) { try { - Optional message = messagesManager.delete(account.getNumber(), - account.getUuid(), + Optional message = messagesManager.delete( + account.getUuid(), account.getAuthenticatedDevice().get().getId(), uuid); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 4203451aa..404b40baa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -5,6 +5,8 @@ package org.whispersystems.textsecuregcm.storage; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; @@ -13,6 +15,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.micrometer.core.instrument.Metrics; +import java.io.IOException; +import java.util.List; +import java.util.Optional; +import java.util.UUID; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; @@ -22,13 +28,6 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; -import java.io.IOException; -import java.util.List; -import java.util.Optional; -import java.util.UUID; - -import static com.codahale.metrics.MetricRegistry.name; - public class AccountsManager { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); @@ -145,7 +144,7 @@ public class AccountsManager { directoryQueue.deleteAccount(account); profilesManager.deleteAll(account.getUuid()); keysDynamoDb.delete(account); - messagesManager.clear(account.getNumber(), account.getUuid()); + messagesManager.clear(account.getUuid()); redisDelete(account); databaseDelete(account); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index fd8142083..52267c02b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -164,7 +164,7 @@ public class MessagePersister implements Managed { do { messages = messagesCache.getMessagesToPersist(accountUuid, deviceId, MESSAGE_BATCH_LIMIT); - messagesManager.persistMessages(accountNumber, accountUuid, deviceId, messages); + messagesManager.persistMessages(accountUuid, deviceId, messages); messageCount += messages.size(); persistMessageMeter.mark(messages.size()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java deleted file mode 100644 index 71811c0f5..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import com.codahale.metrics.Histogram; -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import com.codahale.metrics.SharedMetricRegistries; -import com.codahale.metrics.Timer; -import org.jdbi.v3.core.argument.SetObjectArgumentFactory; -import org.jdbi.v3.core.statement.PreparedBatch; -import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.storage.mappers.OutgoingMessageEntityRowMapper; -import org.whispersystems.textsecuregcm.util.Constants; - -import java.sql.Types; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; - -import static com.codahale.metrics.MetricRegistry.name; - -public class Messages { - - static final int RESULT_SET_CHUNK_SIZE = 100; - - public static final String ID = "id"; - public static final String GUID = "guid"; - public static final String TYPE = "type"; - public static final String RELAY = "relay"; - public static final String TIMESTAMP = "timestamp"; - public static final String SERVER_TIMESTAMP = "server_timestamp"; - public static final String SOURCE = "source"; - public static final String SOURCE_UUID = "source_uuid"; - public static final String SOURCE_DEVICE = "source_device"; - public static final String DESTINATION = "destination"; - public static final String DESTINATION_DEVICE = "destination_device"; - public static final String MESSAGE = "message"; - public static final String CONTENT = "content"; - - private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private final Timer storeTimer = metricRegistry.timer(name(Messages.class, "store" )); - private final Timer loadTimer = metricRegistry.timer(name(Messages.class, "load" )); - private final Timer removeBySourceTimer = metricRegistry.timer(name(Messages.class, "removeBySource")); - private final Timer removeByGuidTimer = metricRegistry.timer(name(Messages.class, "removeByGuid" )); - private final Timer removeByIdTimer = metricRegistry.timer(name(Messages.class, "removeById" )); - private final Timer clearDeviceTimer = metricRegistry.timer(name(Messages.class, "clearDevice" )); - private final Timer clearTimer = metricRegistry.timer(name(Messages.class, "clear" )); - private final Timer vacuumTimer = metricRegistry.timer(name(Messages.class, "vacuum")); - private final Meter insertNullGuidMeter = metricRegistry.meter(name(Messages.class, "insertNullGuid")); - private final Histogram storeSizeHistogram = metricRegistry.histogram(name(Messages.class, "storeBatchSize")); - - private final FaultTolerantDatabase database; - - private static class UUIDArgumentFactory extends SetObjectArgumentFactory { - public UUIDArgumentFactory() { - super(Map.of(UUID.class, Types.OTHER)); - } - } - - public Messages(FaultTolerantDatabase database) { - this.database = database; - this.database.getDatabase().registerRowMapper(new OutgoingMessageEntityRowMapper()); - this.database.getDatabase().registerArgument(new UUIDArgumentFactory()); - } - - public void store(final List messages, final String destination, final long destinationDevice) { - database.use(jdbi -> jdbi.useTransaction(handle -> { - try (final Timer.Context ignored = storeTimer.time()) { - final PreparedBatch batch = handle.prepareBatch("INSERT INTO messages (" + GUID + ", " + TYPE + ", " + RELAY + ", " + TIMESTAMP + ", " + SERVER_TIMESTAMP + ", " + SOURCE + ", " + SOURCE_UUID + ", " + SOURCE_DEVICE + ", " + DESTINATION + ", " + DESTINATION_DEVICE + ", " + MESSAGE + ", " + CONTENT + ") " + - "VALUES (:guid, :type, :relay, :timestamp, :server_timestamp, :source, :source_uuid, :source_device, :destination, :destination_device, :message, :content)"); - - for (final Envelope message : messages) { - if (message.getServerGuid() == null) { - insertNullGuidMeter.mark(); - } - - batch.bind("guid", UUID.fromString(message.getServerGuid())) - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .bind("type", message.getType().getNumber()) - .bind("relay", message.getRelay()) - .bind("timestamp", message.getTimestamp()) - .bind("server_timestamp", message.getServerTimestamp()) - .bind("source", message.hasSource() ? message.getSource() : null) - .bind("source_uuid", message.hasSourceUuid() ? UUID.fromString(message.getSourceUuid()) : null) - .bind("source_device", message.hasSourceDevice() ? message.getSourceDevice() : null) - .bind("message", message.hasLegacyMessage() ? message.getLegacyMessage().toByteArray() : null) - .bind("content", message.hasContent() ? message.getContent().toByteArray() : null) - .add(); - } - - batch.execute(); - storeSizeHistogram.update(messages.size()); - } - })); - } - - public List load(String destination, long destinationDevice) { - return database.with(jdbi-> jdbi.withHandle(handle -> { - try (Timer.Context ignored = loadTimer.time()) { - return handle.createQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC LIMIT " + RESULT_SET_CHUNK_SIZE) - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .mapTo(OutgoingMessageEntity.class) - .list(); - } - })); - } - - public Optional remove(String destination, long destinationDevice, String source, long timestamp) { - return database.with(jdbi -> jdbi.withHandle(handle -> { - try (Timer.Context ignored = removeBySourceTimer.time()) { - return handle.createQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *") - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .bind("source", source) - .bind("timestamp", timestamp) - .mapTo(OutgoingMessageEntity.class) - .findFirst(); - } - })); - } - - public Optional remove(String destination, UUID guid) { - return database.with(jdbi -> jdbi.withHandle(handle -> { - try (Timer.Context ignored = removeByGuidTimer.time()) { - return handle.createQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM MESSAGES WHERE " + GUID + " = :guid AND " + DESTINATION + " = :destination ORDER BY " + ID + " LIMIT 1) RETURNING *") - .bind("destination", destination) - .bind("guid", guid) - .mapTo(OutgoingMessageEntity.class) - .findFirst(); - } - })); - } - - public void remove(String destination, long id) { - database.use(jdbi -> jdbi.useHandle(handle -> { - try (Timer.Context ignored = removeByIdTimer.time()) { - handle.createUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination") - .bind("destination", destination) - .bind("id", id) - .execute(); - } - })); - } - - public void clear(String destination) { - database.use(jdbi ->jdbi.useHandle(handle -> { - try (Timer.Context ignored = clearTimer.time()) { - handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination") - .bind("destination", destination) - .execute(); - } - })); - } - - public void clear(String destination, long destinationDevice) { - database.use(jdbi -> jdbi.useHandle(handle -> { - try (Timer.Context ignored = clearDeviceTimer.time()) { - handle.createUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device") - .bind("destination", destination) - .bind("destination_device", destinationDevice) - .execute(); - } - })); - } - - public void vacuum() { - database.use(jdbi -> jdbi.useHandle(handle -> { - try (Timer.Context ignored = vacuumTimer.time()) { - handle.execute("VACUUM messages"); - } - })); - } - - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 4eade3aab..f1280b5a2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -2,10 +2,8 @@ * Copyright 2013-2020 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ - package org.whispersystems.textsecuregcm.storage; - import static com.codahale.metrics.MetricRegistry.name; import com.codahale.metrics.Meter; @@ -19,14 +17,13 @@ import java.util.stream.Collectors; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.util.Constants; public class MessagesManager { - private static final String DISABLE_RDS_EXPERIMENT = "messages_disable_rds"; + private static final int RESULT_SET_CHUNK_SIZE = 100; private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final Meter cacheHitByNameMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByName" )); @@ -34,18 +31,17 @@ public class MessagesManager { private static final Meter cacheHitByGuidMeter = metricRegistry.meter(name(MessagesManager.class, "cacheHitByGuid" )); private static final Meter cacheMissByGuidMeter = metricRegistry.meter(name(MessagesManager.class, "cacheMissByGuid")); - private final Messages messages; private final MessagesDynamoDb messagesDynamoDb; private final MessagesCache messagesCache; private final PushLatencyManager pushLatencyManager; - private final ExperimentEnrollmentManager experimentEnrollmentManager; - public MessagesManager(Messages messages, MessagesDynamoDb messagesDynamoDb, MessagesCache messagesCache, PushLatencyManager pushLatencyManager, ExperimentEnrollmentManager experimentEnrollmentManager) { - this.messages = messages; + public MessagesManager( + MessagesDynamoDb messagesDynamoDb, + MessagesCache messagesCache, + PushLatencyManager pushLatencyManager) { this.messagesDynamoDb = messagesDynamoDb; this.messagesCache = messagesCache; this.pushLatencyManager = pushLatencyManager; - this.experimentEnrollmentManager = experimentEnrollmentManager; } public void insert(UUID destinationUuid, long destinationDevice, Envelope message) { @@ -64,55 +60,38 @@ public class MessagesManager { return messagesCache.hasMessages(destinationUuid, destinationDevice); } - public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { + public OutgoingMessageEntityList getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); List messageList = new ArrayList<>(); - if (!cachedMessagesOnly && !experimentEnrollmentManager.isEnrolled(destinationUuid, DISABLE_RDS_EXPERIMENT)) { - messageList.addAll(messages.load(destination, destinationDevice)); + if (!cachedMessagesOnly) { + messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE)); } - if (messageList.size() < Messages.RESULT_SET_CHUNK_SIZE && !cachedMessagesOnly) { - messageList.addAll(messagesDynamoDb.load(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messageList.size())); + if (messageList.size() < RESULT_SET_CHUNK_SIZE) { + messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, RESULT_SET_CHUNK_SIZE - messageList.size())); } - if (messageList.size() < Messages.RESULT_SET_CHUNK_SIZE) { - messageList.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messageList.size())); - } - - return new OutgoingMessageEntityList(messageList, messageList.size() >= Messages.RESULT_SET_CHUNK_SIZE); + return new OutgoingMessageEntityList(messageList, messageList.size() >= RESULT_SET_CHUNK_SIZE); } - public void clear(String destination, UUID destinationUuid) { - // TODO Remove this null check in a fully-UUID-ified world - if (destinationUuid != null) { - messagesCache.clear(destinationUuid); - messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); - if (!experimentEnrollmentManager.isEnrolled(destinationUuid, DISABLE_RDS_EXPERIMENT)) { - messages.clear(destination); - } - } else { - messages.clear(destination); - } + public void clear(UUID destinationUuid) { + messagesCache.clear(destinationUuid); + messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid); } - public void clear(String destination, UUID destinationUuid, long deviceId) { + public void clear(UUID destinationUuid, long deviceId) { messagesCache.clear(destinationUuid, deviceId); messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, deviceId); - if (!experimentEnrollmentManager.isEnrolled(destinationUuid, DISABLE_RDS_EXPERIMENT)) { - messages.clear(destination, deviceId); - } } - public Optional delete(String destination, UUID destinationUuid, long destinationDevice, String source, long timestamp) { - Optional removed = messagesCache.remove(destinationUuid, destinationDevice, source, timestamp); + public Optional delete( + UUID destinationUuid, long destinationDeviceId, String source, long timestamp) { + Optional removed = messagesCache.remove(destinationUuid, destinationDeviceId, source, timestamp); if (removed.isEmpty()) { - removed = messagesDynamoDb.deleteMessageByDestinationAndSourceAndTimestamp(destinationUuid, destinationDevice, source, timestamp); - if (removed.isEmpty() && !experimentEnrollmentManager.isEnrolled(destinationUuid, DISABLE_RDS_EXPERIMENT)) { - removed = messages.remove(destination, destinationDevice, source, timestamp); - } + removed = messagesDynamoDb.deleteMessageByDestinationAndSourceAndTimestamp(destinationUuid, destinationDeviceId, source, timestamp); cacheMissByNameMeter.mark(); } else { cacheHitByNameMeter.mark(); @@ -121,14 +100,11 @@ public class MessagesManager { return removed; } - public Optional delete(String destination, UUID destinationUuid, long deviceId, UUID guid) { - Optional removed = messagesCache.remove(destinationUuid, deviceId, guid); + public Optional delete(UUID destinationUuid, long destinationDeviceId, UUID guid) { + Optional removed = messagesCache.remove(destinationUuid, destinationDeviceId, guid); if (removed.isEmpty()) { - removed = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, deviceId, guid); - if (removed.isEmpty() && !experimentEnrollmentManager.isEnrolled(destinationUuid, DISABLE_RDS_EXPERIMENT)) { - removed = messages.remove(destination, guid); - } + removed = messagesDynamoDb.deleteMessageByDestinationAndGuid(destinationUuid, destinationDeviceId, guid); cacheMissByGuidMeter.mark(); } else { cacheHitByGuidMeter.mark(); @@ -137,18 +113,19 @@ public class MessagesManager { return removed; } - @Deprecated - public void delete(String destination, long id) { - messages.remove(destination, id); - } - - public void persistMessages(final String destination, final UUID destinationUuid, final long destinationDeviceId, final List messages) { + public void persistMessages( + final UUID destinationUuid, + final long destinationDeviceId, + final List messages) { messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); messagesCache.remove(destinationUuid, destinationDeviceId, messages.stream().map(message -> UUID.fromString(message.getServerGuid())).collect(Collectors.toList())); } - public void addMessageAvailabilityListener(final UUID destinationUuid, final long deviceId, final MessageAvailabilityListener listener) { - messagesCache.addMessageAvailabilityListener(destinationUuid, deviceId, listener); + public void addMessageAvailabilityListener( + final UUID destinationUuid, + final long destinationDeviceId, + final MessageAvailabilityListener listener) { + messagesCache.addMessageAvailabilityListener(destinationUuid, destinationDeviceId, listener); } public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java deleted file mode 100644 index e6844e250..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/mappers/OutgoingMessageEntityRowMapper.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage.mappers; - -import org.jdbi.v3.core.mapper.RowMapper; -import org.jdbi.v3.core.statement.StatementContext; -import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.storage.Messages; - -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.UUID; - -public class OutgoingMessageEntityRowMapper implements RowMapper { - @Override - public OutgoingMessageEntity map(ResultSet resultSet, StatementContext ctx) throws SQLException { - int type = resultSet.getInt(Messages.TYPE); - byte[] legacyMessage = resultSet.getBytes(Messages.MESSAGE); - String guid = resultSet.getString(Messages.GUID); - String sourceUuid = resultSet.getString(Messages.SOURCE_UUID); - - if (type == Envelope.Type.RECEIPT_VALUE && legacyMessage == null) { - /// XXX - REMOVE AFTER 10/01/15 - legacyMessage = new byte[0]; - } - - return new OutgoingMessageEntity(resultSet.getLong(Messages.ID), - false, - guid == null ? null : UUID.fromString(guid), - type, - resultSet.getString(Messages.RELAY), - resultSet.getLong(Messages.TIMESTAMP), - resultSet.getString(Messages.SOURCE), - sourceUuid == null ? null : UUID.fromString(sourceUuid), - resultSet.getInt(Messages.SOURCE_DEVICE), - legacyMessage, - resultSet.getBytes(Messages.CONTENT), - resultSet.getLong(Messages.SERVER_TIMESTAMP)); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 19439b16b..912dcf880 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -5,6 +5,9 @@ package org.whispersystems.textsecuregcm.websocket; +import static com.codahale.metrics.MetricRegistry.name; +import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; + import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; @@ -13,6 +16,19 @@ import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.LongAdder; +import javax.ws.rs.WebApplicationException; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,23 +52,6 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import javax.ws.rs.WebApplicationException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.atomic.LongAdder; - -import static com.codahale.metrics.MetricRegistry.name; -import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener { @@ -144,7 +143,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac if (throwable == null) { if (isSuccessResponse(response)) { if (storedMessageInfo.isPresent()) { - messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), storedMessageInfo.get().getGuid()); + messagesManager.delete(account.getUuid(), device.getId(), storedMessageInfo.get().getGuid()); } if (message.getType() != Envelope.Type.RECEIPT) { @@ -218,7 +217,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { - final OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); + final OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); final CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; for (int i = 0; i < messages.getMessages().size(); i++) { @@ -250,12 +249,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final Envelope envelope = builder.build(); - if (message.getGuid() == null || (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient)) { - if (message.getGuid() == null) { - messagesManager.delete(account.getNumber(), message.getId()); // TODO(ehren): Remove once the message DB is gone. - } else { - messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), message.getGuid()); - } + if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { + messagesManager.delete(account.getUuid(), device.getId(), message.getGuid()); discardedMessagesMeter.mark(); sendFutures[i] = CompletableFuture.completedFuture(null); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java index 544e55cdf..b991bb0cd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/DeleteUserCommand.java @@ -5,6 +5,8 @@ package org.whispersystems.textsecuregcm.workers; +import static com.codahale.metrics.MetricRegistry.name; + import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.InstanceProfileCredentialsProvider; import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder; @@ -15,13 +17,14 @@ import io.dropwizard.cli.EnvironmentCommand; import io.dropwizard.jdbi3.JdbiFactory; import io.dropwizard.setup.Environment; import io.lettuce.core.resource.ClientResources; +import java.util.Optional; +import java.util.concurrent.ExecutorService; import net.sourceforge.argparse4j.inf.Namespace; import net.sourceforge.argparse4j.inf.Subparser; import org.jdbi.v3.core.Jdbi; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; @@ -31,7 +34,6 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; -import org.whispersystems.textsecuregcm.storage.Messages; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -41,11 +43,6 @@ import org.whispersystems.textsecuregcm.storage.ReservedUsernames; import org.whispersystems.textsecuregcm.storage.Usernames; import org.whispersystems.textsecuregcm.storage.UsernamesManager; -import java.util.Optional; -import java.util.concurrent.ExecutorService; - -import static com.codahale.metrics.MetricRegistry.name; - public class DeleteUserCommand extends EnvironmentCommand { private final Logger logger = LoggerFactory.getLogger(DeleteUserCommand.class); @@ -83,9 +80,7 @@ public class DeleteUserCommand extends EnvironmentCommand throws Exception { DatabaseConfiguration accountDbConfig = config.getAbuseDatabaseConfiguration(); - DatabaseConfiguration messageDbConfig = config.getMessageStoreConfiguration(); - Jdbi accountJdbi = Jdbi.create(accountDbConfig.getUrl(), accountDbConfig.getUser(), accountDbConfig.getPassword()); - Jdbi messageJdbi = Jdbi.create(messageDbConfig.getUrl(), messageDbConfig.getUser(), messageDbConfig.getPassword()); - FaultTolerantDatabase accountDatabase = new FaultTolerantDatabase("account_database_vacuum", accountJdbi, accountDbConfig.getCircuitBreakerConfiguration()); - FaultTolerantDatabase messageDatabase = new FaultTolerantDatabase("message_database_vacuum", messageJdbi, messageDbConfig.getCircuitBreakerConfiguration()); Accounts accounts = new Accounts(accountDatabase); PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase); - Messages messages = new Messages(messageDatabase); FeatureFlags featureFlags = new FeatureFlags(accountDatabase); logger.info("Vacuuming accounts..."); @@ -55,9 +48,6 @@ public class VacuumCommand extends ConfiguredCommand logger.info("Vacuuming pending_accounts..."); pendingAccounts.vacuum(); - logger.info("Vacuuming messages..."); - messages.vacuum(); - logger.info("Vacuuming feature flags..."); featureFlags.vacuum(); diff --git a/service/src/main/resources/messagedb.xml b/service/src/main/resources/messagedb.xml deleted file mode 100644 index 02f311dc2..000000000 --- a/service/src/main/resources/messagedb.xml +++ /dev/null @@ -1,139 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - CREATE RULE bounded_message_queue AS ON INSERT TO messages DO ALSO DELETE FROM messages WHERE id IN (SELECT id FROM messages WHERE destination = NEW.destination AND destination_device = NEW.destination_device ORDER BY timestamp DESC OFFSET 5000); - - - - DROP RULE bounded_message_queue ON messages; - CREATE RULE bounded_message_queue AS ON INSERT TO messages DO ALSO DELETE FROM messages WHERE id IN (SELECT id FROM messages WHERE destination = NEW.destination AND destination_device = NEW.destination_device ORDER BY timestamp DESC OFFSET 1000); - - - - - - - - DROP RULE bounded_message_queue ON messages; - - - - CREATE RULE bounded_message_queue AS ON INSERT TO messages DO ALSO DELETE FROM messages WHERE id IN (SELECT id FROM messages WHERE destination = NEW.destination AND destination_device = NEW.destination_device ORDER BY timestamp DESC OFFSET 1000); - - - - - - DROP RULE bounded_message_queue ON messages; - - - - CREATE RULE bounded_message_queue AS ON INSERT TO messages DO ALSO DELETE FROM messages WHERE id IN (SELECT id FROM messages WHERE destination = NEW.destination AND destination_device = NEW.destination_device ORDER BY timestamp DESC OFFSET 1000); - - - - DROP RULE bounded_message_queue ON messages; - - - - CREATE RULE bounded_message_queue AS ON INSERT TO messages DO ALSO DELETE FROM messages WHERE id IN (SELECT id FROM messages WHERE destination = NEW.destination AND destination_device = NEW.destination_device ORDER BY timestamp DESC OFFSET 1000); - - - - - - - - - - - - - - - - - CREATE INDEX CONCURRENTLY guid_index ON messages (guid); - - - - - - - - - - DROP INDEX CONCURRENTLY IF EXISTS public.destination_index; - - - diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index c5dfd97d9..1a3bcc6fd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -5,6 +5,10 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import com.amazonaws.services.dynamodbv2.document.DynamoDB; import com.amazonaws.services.dynamodbv2.document.Item; import com.amazonaws.services.dynamodbv2.document.ItemCollection; @@ -12,23 +16,7 @@ import com.amazonaws.services.dynamodbv2.document.ScanOutcome; import com.amazonaws.services.dynamodbv2.document.Table; import com.amazonaws.services.dynamodbv2.document.spec.ScanSpec; import com.google.protobuf.ByteString; -import com.opentable.db.postgres.embedded.LiquibasePreparer; -import com.opentable.db.postgres.junit.EmbeddedPostgresRules; -import com.opentable.db.postgres.junit.PreparedDbRule; import io.lettuce.core.cluster.SlotHash; -import org.apache.commons.lang3.RandomStringUtils; -import org.jdbi.v3.core.Jdbi; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; -import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; -import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; -import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; - import java.nio.ByteBuffer; import java.time.Duration; import java.time.Instant; @@ -40,18 +28,18 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; - -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; public class MessagePersisterIntegrationTest extends AbstractRedisClusterTest { - @Rule - public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("messagedb.xml")); - @Rule public MessagesDynamoDbRule messagesDynamoDbRule = new MessagesDynamoDbRule(); @@ -60,7 +48,6 @@ public class MessagePersisterIntegrationTest extends AbstractRedisClusterTest { private MessagesManager messagesManager; private MessagePersister messagePersister; private Account account; - private ExperimentEnrollmentManager experimentEnrollmentManager; private static final Duration PERSIST_DELAY = Duration.ofMinutes(10); @@ -74,16 +61,12 @@ public class MessagePersisterIntegrationTest extends AbstractRedisClusterTest { connection.sync().masters().commands().configSet("notify-keyspace-events", "K$glz"); }); - final Messages messages = new Messages(new FaultTolerantDatabase("messages-test", Jdbi.create(db.getTestDatabase()), new CircuitBreakerConfiguration())); final MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messagesDynamoDbRule.getDynamoDB(), MessagesDynamoDbRule.TABLE_NAME, Duration.ofDays(7)); final AccountsManager accountsManager = mock(AccountsManager.class); - experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); - when(experimentEnrollmentManager.isEnrolled(any(UUID.class), anyString())).thenReturn(Boolean.TRUE); - notificationExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); - messagesManager = new MessagesManager(messages, messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), experimentEnrollmentManager); + messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class)); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, mock(FeatureFlagsManager.class), PERSIST_DELAY); account = mock(Account.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 4af7e66f3..4ce3a270d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -5,16 +5,19 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import com.google.protobuf.ByteString; import io.lettuce.core.cluster.SlotHash; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.stubbing.Answer; -import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; - import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; @@ -24,24 +27,19 @@ import java.util.UUID; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; - -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; public class MessagePersisterTest extends AbstractRedisClusterTest { private ExecutorService notificationExecutorService; private MessagesCache messagesCache; - private Messages messagesDatabase; + private MessagesDynamoDb messagesDynamoDb; private MessagePersister messagePersister; private AccountsManager accountsManager; @@ -58,7 +56,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { final MessagesManager messagesManager = mock(MessagesManager.class); - messagesDatabase = mock(Messages.class); + messagesDynamoDb = mock(MessagesDynamoDb.class); accountsManager = mock(AccountsManager.class); final Account account = mock(Account.class); @@ -71,19 +69,18 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, mock(FeatureFlagsManager.class), PERSIST_DELAY); doAnswer(invocation -> { - final String destination = invocation.getArgument(0, String.class); - final UUID destinationUuid = invocation.getArgument(1, UUID.class); - final long deviceId = invocation.getArgument(2, Long.class); - final List messages = invocation.getArgument(3, List.class); + final UUID destinationUuid = invocation.getArgument(0); + final long destinationDeviceId = invocation.getArgument(1); + final List messages = invocation.getArgument(2); - messagesDatabase.store(messages, destination, deviceId); + messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); for (final MessageProtos.Envelope message : messages) { - messagesCache.remove(destinationUuid, deviceId, UUID.fromString(message.getServerGuid())); + messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); } return null; - }).when(messagesManager).persistMessages(anyString(), any(UUID.class), anyLong(), any()); + }).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any()); } @Override @@ -114,7 +111,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @@ -129,7 +126,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { messagePersister.persistNextQueues(now); - verify(messagesDatabase, never()).store(any(), anyString(), anyLong()); + verify(messagesDynamoDb, never()).store(any(), any(), anyLong()); } @Test @@ -159,7 +156,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - verify(messagesDatabase, atLeastOnce()).store(messagesCaptor.capture(), anyString(), anyLong()); + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong()); assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); } @@ -174,7 +171,7 @@ public class MessagePersisterTest extends AbstractRedisClusterTest { doAnswer((Answer)invocation -> { throw new RuntimeException("OH NO."); - }).when(messagesDatabase).store(any(), eq(DESTINATION_ACCOUNT_NUMBER), eq(DESTINATION_DEVICE_ID)); + }).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 83fa3df47..b919f28bd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -4,9 +4,25 @@ */ package org.whispersystems.textsecuregcm.tests.controllers; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit.ResourceTestRule; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import javax.ws.rs.Path; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; @@ -31,23 +47,6 @@ import org.whispersystems.textsecuregcm.storage.PendingDevicesManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.VerificationCode; -import javax.ws.rs.Path; -import javax.ws.rs.client.Entity; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.TimeUnit; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - @RunWith(JUnitParamsRunner.class) public class DeviceControllerTest { @Path("/v1/devices") @@ -143,7 +142,7 @@ public class DeviceControllerTest { assertThat(response.getDeviceId()).isEqualTo(42L); verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); - verify(messagesManager).clear(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(42L)); + verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index 6fc08a68d..fb7f5704f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -66,7 +66,6 @@ import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.FeatureFlagsManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.Base64; @@ -87,7 +86,6 @@ public class MessageControllerTest { private final RateLimiter rateLimiter = mock(RateLimiter.class); private final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class); private final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class); - private final FeatureFlagsManager featureFlagsManager = mock(FeatureFlagsManager.class); private final ObjectMapper mapper = new ObjectMapper(); @@ -281,7 +279,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -318,7 +316,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); Response response = resources.getJerseyTest().target("/v1/messages/") @@ -336,20 +334,20 @@ public class MessageControllerTest { UUID sourceUuid = UUID.randomUUID(); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31337)) + when(messagesManager.delete(AuthHelper.VALID_UUID, 1, "+14152222222", 31337)) .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, "+14152222222", sourceUuid, 1, "hi".getBytes(), null, 0))); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31338)) + when(messagesManager.delete(AuthHelper.VALID_UUID, 1, "+14152222222", 31338)) .thenReturn(Optional.of(new OutgoingMessageEntity(31337L, true, null, Envelope.Type.RECEIPT_VALUE, null, System.currentTimeMillis(), "+14152222222", sourceUuid, 1, null, null, 0))); - when(messagesManager.delete(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, 1, "+14152222222", 31339)) + when(messagesManager.delete(AuthHelper.VALID_UUID, 1, "+14152222222", 31339)) .thenReturn(Optional.empty()); Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java deleted file mode 100644 index 9fb46f799..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/MessagesTest.java +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.tests.storage; - -import com.google.protobuf.ByteString; -import com.opentable.db.postgres.embedded.LiquibasePreparer; -import com.opentable.db.postgres.junit.EmbeddedPostgresRules; -import com.opentable.db.postgres.junit.PreparedDbRule; -import junitparams.JUnitParamsRunner; -import junitparams.Parameters; -import org.jdbi.v3.core.Jdbi; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; -import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; -import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase; -import org.whispersystems.textsecuregcm.storage.Messages; - -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; -import java.util.Queue; -import java.util.Random; -import java.util.UUID; -import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.Collectors; - -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -@RunWith(JUnitParamsRunner.class) -public class MessagesTest { - - @Rule - public PreparedDbRule db = EmbeddedPostgresRules.preparedDatabase(LiquibasePreparer.forClasspathLocation("messagedb.xml")); - - private Messages messages; - - private long serialTimestamp = 0; - - @Before - public void setupAccountsDao() { - this.messages = new Messages(new FaultTolerantDatabase("messages-test", Jdbi.create(db.getTestDatabase()), new CircuitBreakerConfiguration())); - } - - @Test - public void testStore() throws SQLException { - Envelope envelope = generateEnvelope(); - - messages.store(List.of(envelope), "+14151112222", 1); - - PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?"); - statement.setString(1, "+14151112222"); - - ResultSet resultSet = statement.executeQuery(); - assertThat(resultSet.next()).isTrue(); - - assertThat(resultSet.getString("guid")).isEqualTo(envelope.getServerGuid()); - assertThat(resultSet.getInt("type")).isEqualTo(envelope.getType().getNumber()); - assertThat(resultSet.getString("relay")).isNullOrEmpty(); - assertThat(resultSet.getLong("timestamp")).isEqualTo(envelope.getTimestamp()); - assertThat(resultSet.getLong("server_timestamp")).isEqualTo(envelope.getServerTimestamp()); - assertThat(resultSet.getString("source")).isEqualTo(envelope.getSource()); - assertThat(resultSet.getLong("source_device")).isEqualTo(envelope.getSourceDevice()); - assertThat(resultSet.getBytes("message")).isEqualTo(envelope.getLegacyMessage().toByteArray()); - assertThat(resultSet.getBytes("content")).isEqualTo(envelope.getContent().toByteArray()); - assertThat(resultSet.getString("destination")).isEqualTo("+14151112222"); - assertThat(resultSet.getLong("destination_device")).isEqualTo(1); - - assertThat(resultSet.next()).isFalse(); - } - - @Test - @Parameters(method = "argumentsForTestStoreSealedSenderBatch") - public void testStoreSealedSenderBatch(final List sealedSenderSequence) throws Exception { - final String destinationNumber = "+14151234567"; - - final List envelopes = sealedSenderSequence.stream() - .map(sealedSender -> { - if (sealedSender) { - return generateEnvelope().toBuilder().clearSourceUuid().clearSource().clearSourceDevice().build(); - } else { - return generateEnvelope().toBuilder().setSourceUuid(UUID.randomUUID().toString()).setSource("+18005551234").setSourceDevice(4).build(); - } - }).collect(Collectors.toList()); - - messages.store(envelopes, destinationNumber, 1); - - final Queue expectedMessages = new ArrayDeque<>(envelopes); - - try (final PreparedStatement statement = db.getTestDatabase().getConnection().prepareStatement("SELECT * FROM messages WHERE destination = ?")) { - statement.setString(1, destinationNumber); - - try (final ResultSet resultSet = statement.executeQuery()) { - while (resultSet.next() && !expectedMessages.isEmpty()) { - assertRowEqualsEnvelope(resultSet, destinationNumber, expectedMessages.poll()); - } - - assertThat(resultSet.next()).isFalse(); - assertThat(expectedMessages.isEmpty()); - } - } - } - - private static Object argumentsForTestStoreSealedSenderBatch() { - return new Object[] { - List.of(true), - List.of(false), - List.of(true, false), - List.of(false, true) - }; - } - - private void assertRowEqualsEnvelope(final ResultSet resultSet, final String expectedDestination, final Envelope expectedMessage) throws SQLException { - assertThat(resultSet.getString("guid")).isEqualTo(expectedMessage.getServerGuid()); - assertThat(resultSet.getInt("type")).isEqualTo(expectedMessage.getType().getNumber()); - assertThat(resultSet.getString("relay")).isNullOrEmpty(); - assertThat(resultSet.getLong("timestamp")).isEqualTo(expectedMessage.getTimestamp()); - assertThat(resultSet.getLong("server_timestamp")).isEqualTo(expectedMessage.getServerTimestamp()); - assertThat(resultSet.getBytes("message")).isEqualTo(expectedMessage.getLegacyMessage().toByteArray()); - assertThat(resultSet.getBytes("content")).isEqualTo(expectedMessage.getContent().toByteArray()); - assertThat(resultSet.getString("destination")).isEqualTo(expectedDestination); - assertThat(resultSet.getLong("destination_device")).isEqualTo(1); - - if (expectedMessage.hasSource()) { - assertThat(resultSet.getString("source")).isEqualTo(expectedMessage.getSource()); - } else { - assertThat(resultSet.getString("source")).isNullOrEmpty(); - } - - if (expectedMessage.hasSourceDevice()) { - assertThat(resultSet.getLong("source_device")).isEqualTo(expectedMessage.getSourceDevice()); - } else { - assertThat(resultSet.getLong("source_device")).isEqualTo(0); - } - - if (expectedMessage.hasSourceUuid()) { - assertThat(resultSet.getString("source_uuid")).isEqualTo(expectedMessage.getSourceUuid()); - } else { - assertThat(resultSet.getString("source_uuid")).isNull(); - } - } - - @Test - public void testLoad() { - List inserted = insertRandom("+14151112222", 1); - - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); - - List retrieved = messages.load("+14151112222", 1); - - assertThat(retrieved.size()).isEqualTo(inserted.size()); - - for (int i=0;i inserted = insertRandom("+14151112222", 1); - List unrelated = insertRandom("+14151114444", 3); - Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", 1, toRemove.getSource(), toRemove.getTimestamp()); - - assertThat(removed.isPresent()).isTrue(); - verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid())); - - verifyInTact(inserted, "+14151112222", 1); - verifyInTact(unrelated, "+14151114444", 3); - } - - @Test - public void removeByDestinationGuid() { - List unrelated = insertRandom("+14151113333", 2); - List inserted = insertRandom("+14151112222", 1); - Envelope toRemove = inserted.remove(new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1)); - Optional removed = messages.remove("+14151112222", UUID.fromString(toRemove.getServerGuid())); - - assertThat(removed.isPresent()).isTrue(); - verifyExpected(removed.get(), toRemove, UUID.fromString(toRemove.getServerGuid())); - - verifyInTact(inserted, "+14151112222", 1); - verifyInTact(unrelated, "+14151113333", 2); - } - - @Test - public void removeByDestinationRowId() { - List unrelatedInserted = insertRandom("+14151111111", 1); - List inserted = insertRandom("+14151112222", 1); - - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); - - List retrieved = messages.load("+14151112222", 1); - - int toRemoveIndex = new Random(System.currentTimeMillis()).nextInt(inserted.size() - 1); - - inserted.remove(toRemoveIndex); - - messages.remove("+14151112222", retrieved.get(toRemoveIndex).getId()); - - verifyInTact(inserted, "+14151112222", 1); - verifyInTact(unrelatedInserted, "+14151111111", 1); - } - - @Test - public void testLoadEmpty() { - insertRandom("+14151112222", 1); - assertThat(messages.load("+14159999999", 1).isEmpty()).isTrue(); - } - - @Test - public void testClearDestination() { - insertRandom("+14151112222", 1); - insertRandom("+14151112222", 2); - - List unrelated = insertRandom("+14151111111", 1); - - messages.clear("+14151112222"); - - assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue(); - - verifyInTact(unrelated, "+14151111111", 1); - } - - @Test - public void testClearDestinationDevice() { - insertRandom("+14151112222", 1); - List inserted = insertRandom("+14151112222", 2); - - List unrelated = insertRandom("+14151111111", 1); - - messages.clear("+14151112222", 1); - - assertThat(messages.load("+14151112222", 1).isEmpty()).isTrue(); - - verifyInTact(inserted, "+14151112222", 2); - verifyInTact(unrelated, "+14151111111", 1); - } - - @Test - public void testVacuum() { - List inserted = insertRandom("+14151112222", 2); - messages.vacuum(); - verifyInTact(inserted, "+14151112222", 2); - } - - private List insertRandom(String destination, int destinationDevice) { - List inserted = new ArrayList<>(50); - - for (int i=0;i<50;i++) { - inserted.add(generateEnvelope()); - } - - messages.store(inserted, destination, destinationDevice); - - return inserted; - } - - private void verifyInTact(List inserted, String destination, int destinationDevice) { - inserted.sort(Comparator.comparingLong(Envelope::getTimestamp)); - - List retrieved = messages.load(destination, destinationDevice); - - assertThat(retrieved.size()).isEqualTo(inserted.size()); - - for (int i=0;i> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); verify(webSocketClient, times(persistedMessageCount + cachedMessageCount)).sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); @@ -203,7 +186,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest persistedMessages.add(generateRandomMessage(UUID.randomUUID())); } - messages.store(persistedMessages, account.getNumber(), device.getId()); + messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); } for (int i = 0; i < cachedMessageCount; i++) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index a04756524..77d3e8462 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -160,7 +160,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -192,7 +192,7 @@ public class WebSocketConnectionTest { futures.get(0).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException()); - verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid())); + verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid())); verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L)); connection.stop(); @@ -212,7 +212,7 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false)) .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false)); @@ -312,7 +312,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -363,7 +363,7 @@ public class WebSocketConnectionTest { final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); threadWaiting.notifyAll(); @@ -407,7 +407,7 @@ public class WebSocketConnectionTest { thread.join(); } - verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false)); + verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), anyString(), eq(false)); } @Test(timeout = 5000L) @@ -431,7 +431,7 @@ public class WebSocketConnectionTest { final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)) .thenReturn(firstPage) .thenReturn(secondPage); @@ -468,7 +468,7 @@ public class WebSocketConnectionTest { final List messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first")); final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -516,7 +516,7 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -554,7 +554,7 @@ public class WebSocketConnectionTest { final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(firstPage) .thenReturn(secondPage) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); @@ -592,7 +592,7 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -604,11 +604,11 @@ public class WebSocketConnectionTest { // anything. connection.processStoredMessages(); - verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), false); connection.handleNewMessagesAvailable(); - verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true); + verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), true); } @Test @@ -624,7 +624,7 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); @@ -637,7 +637,7 @@ public class WebSocketConnectionTest { connection.processStoredMessages(); connection.handleMessagesPersisted(); - verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), false); } @Test @@ -676,7 +676,7 @@ public class WebSocketConnectionTest { String userAgent = "Signal-Desktop/1.2.3"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -707,7 +707,7 @@ public class WebSocketConnectionTest { // We should delete all three messages even though we only sent two; one got discarded because it was too big for // desktop clients. - verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), any(UUID.class)); + verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class)); connection.stop(); verify(client).close(anyInt(), anyString()); @@ -749,7 +749,7 @@ public class WebSocketConnectionTest { String userAgent = "Signal-Android/4.68.3"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -779,7 +779,7 @@ public class WebSocketConnectionTest { futures.get(1).complete(response); futures.get(2).complete(response); - verify(storedMessages, times(3)).delete(eq(account.getNumber()), eq(accountUuid), eq(2L), any(UUID.class)); + verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class)); connection.stop(); verify(client).close(anyInt(), anyString());