diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 5a59c2bf0..1b9bc9f70 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -555,7 +555,7 @@ public class WhisperServerService extends Application(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))); environment.jersey().register(new TimestampResponseFilter()); environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, usernamesManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator, verifyExperimentEnrollmentManager)); - environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices())); + environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keysDynamoDb, directoryQueue, rateLimiters, config.getMaxDevices())); environment.jersey().register(new DirectoryController(directoryCredentialsGenerator)); environment.jersey().register(new ProvisioningController(rateLimiters, provisioningManager)); environment.jersey().register(new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, isZkEnabled)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index a7b76c165..8b3b7e39d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -23,6 +23,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.util.Util; @@ -58,6 +59,7 @@ public class DeviceController { private final StoredVerificationCodeManager pendingDevices; private final AccountsManager accounts; private final MessagesManager messages; + private final KeysDynamoDb keys; private final RateLimiters rateLimiters; private final Map maxDeviceConfiguration; private final DirectoryQueue directoryQueue; @@ -65,6 +67,7 @@ public class DeviceController { public DeviceController(StoredVerificationCodeManager pendingDevices, AccountsManager accounts, MessagesManager messages, + KeysDynamoDb keys, DirectoryQueue directoryQueue, RateLimiters rateLimiters, Map maxDeviceConfiguration) @@ -72,6 +75,7 @@ public class DeviceController { this.pendingDevices = pendingDevices; this.accounts = accounts; this.messages = messages; + this.keys = keys; this.directoryQueue = directoryQueue; this.rateLimiters = rateLimiters; this.maxDeviceConfiguration = maxDeviceConfiguration; @@ -102,6 +106,7 @@ public class DeviceController { messages.clear(account.getUuid(), deviceId); account = accounts.update(account, a -> a.removeDevice(deviceId)); directoryQueue.refreshRegisteredUser(account); + keys.delete(account, deviceId); // ensure any messages that came in after the first clear() are also removed messages.clear(account.getUuid(), deviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java index 3fcc9dbb1..fb0c924ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysDynamoDb.java @@ -168,8 +168,7 @@ public class KeysDynamoDb extends AbstractDynamoDbStore { }); } - @VisibleForTesting - void delete(final Account account, final long deviceId) { + public void delete(final Account account, final long deviceId) { DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { final QueryRequest queryRequest = QueryRequest.builder() .tableName(tableName) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 409bdc176..415a9caf2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -21,7 +22,6 @@ import io.dropwizard.testing.junit.ResourceTestRule; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeUnit; import javax.ws.rs.Path; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; @@ -33,6 +33,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentMatcher; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.controllers.DeviceController; @@ -46,6 +47,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; +import org.whispersystems.textsecuregcm.storage.KeysDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; @@ -59,11 +61,12 @@ public class DeviceControllerTest { public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices, AccountsManager accounts, MessagesManager messages, + KeysDynamoDb keys, DirectoryQueue cdsSender, RateLimiters rateLimiters, Map deviceConfiguration) { - super(pendingDevices, accounts, messages, cdsSender, rateLimiters, deviceConfiguration); + super(pendingDevices, accounts, messages, keys, cdsSender, rateLimiters, deviceConfiguration); } @Override @@ -75,6 +78,7 @@ public class DeviceControllerTest { private StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class); private AccountsManager accountsManager = mock(AccountsManager.class ); private MessagesManager messagesManager = mock(MessagesManager.class); + private KeysDynamoDb keys = mock(KeysDynamoDb.class); private DirectoryQueue directoryQueue = mock(DirectoryQueue.class); private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiter rateLimiter = mock(RateLimiter.class ); @@ -95,6 +99,7 @@ public class DeviceControllerTest { .addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, messagesManager, + keys, directoryQueue, rateLimiters, deviceConfiguration)) @@ -346,7 +351,7 @@ public class DeviceControllerTest { } @Test - public void deviceRemovalClearsMessages() { + public void deviceRemovalClearsMessagesAndKeys() { // this is a static mock, so it might have previous invocations clearInvocations(AuthHelper.VALID_ACCOUNT); @@ -366,6 +371,9 @@ public class DeviceControllerTest { verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId); verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any()); verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId); + + // The account instance may have changed as part of a call to `AccountManager#update` + verify(keys).delete(argThat(account -> account.getUuid().equals(AuthHelper.VALID_UUID)), eq(deviceId)); } }