Move "remove device" logic into `AccountsManager`

This commit is contained in:
Jon Chambers 2023-11-29 12:07:57 -05:00 committed by Jon Chambers
parent 4f42c10d60
commit 37e3bcfc3e
9 changed files with 88 additions and 58 deletions

View File

@ -135,8 +135,7 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{device_id}")
@ChangesDeviceEnabledState
public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") byte deviceId) {
Account account = auth.getAccount();
public CompletableFuture<Response> removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") byte deviceId) {
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -145,14 +144,7 @@ public class DeviceController {
throw new ForbiddenException();
}
final CompletableFuture<Void> deleteKeysFuture = keys.delete(account.getUuid(), deviceId);
messages.clear(account.getUuid(), deviceId).join();
account = accounts.update(account, a -> a.removeDevice(deviceId));
// ensure any messages that came in after the first clear() are also removed
messages.clear(account.getUuid(), deviceId).join();
deleteKeysFuture.join();
return accounts.removeDevice(auth.getAccount(), deviceId).thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@GET

View File

@ -78,19 +78,14 @@ public class DevicesGrpcService extends ReactorDevicesGrpc.DevicesImplBase {
if (request.getId() == Device.PRIMARY_ID) {
throw Status.INVALID_ARGUMENT.withDescription("Cannot remove primary device").asRuntimeException();
}
final byte deviceId = DeviceIdUtil.validate(request.getId());
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedPrimaryDevice();
return Mono.fromFuture(() -> accountsManager.getByAccountIdentifierAsync(authenticatedDevice.accountIdentifier()))
.map(maybeAccount -> maybeAccount.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException))
.flatMap(account -> Flux.merge(
Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId)),
Mono.fromFuture(() -> keysManager.delete(account.getUuid(), deviceId)))
.then(Mono.fromFuture(() -> accountsManager.updateAsync(account, a -> a.removeDevice(deviceId))))
// Some messages may have arrived while we were performing the other updates; make a best effort to clear
// those out, too
.then(Mono.fromFuture(() -> messagesManager.clear(account.getUuid(), deviceId))))
.flatMap(account -> Mono.fromFuture(accountsManager.removeDevice(account, deviceId)))
.thenReturn(RemoveDeviceResponse.newBuilder().build());
}

View File

@ -296,6 +296,25 @@ public class AccountsManager {
}
}
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
}
return CompletableFuture.allOf(
keysManager.delete(account.getUuid(), deviceId),
messagesManager.clear(account.getUuid(), deviceId))
.thenCompose(ignored -> updateAsync(account, (Consumer<Account>) a -> a.removeDevice(deviceId)))
// ensure any messages that came in after the first clear() are also removed
.thenCompose(updatedAccount -> messagesManager.clear(account.getUuid(), deviceId)
.thenApply(ignored -> updatedAccount))
.whenComplete((ignored, throwable) -> {
if (throwable == null) {
clientPresenceManager.disconnectPresence(account.getUuid(), deviceId);
}
});
}
public Account changeNumber(final Account account,
final String targetNumber,
@Nullable final IdentityKey pniIdentityKey,

View File

@ -16,11 +16,6 @@ import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Counter;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
import java.time.Duration;
import java.time.Instant;
import java.util.Comparator;
@ -29,12 +24,7 @@ import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -42,6 +32,10 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.ItemCollectionSizeLimitExceededException;
public class MessagePersister implements Managed {
@ -276,7 +270,7 @@ public class MessagePersister implements Managed {
messagesManager.clear(account.getUuid(), deviceToDelete.getId()))
.orTimeout((UNLINK_TIMEOUT.toSeconds() * 3) / 4, TimeUnit.SECONDS)
.join();
accountsManager.update(updatedAccount, a -> a.removeDevice(deviceToDelete.getId()));
accountsManager.removeDevice(updatedAccount, deviceToDelete.getId()).join();
} finally {
messagesCache.unlockAccountForMessagePersisterCleanup(account.getUuid());
if (deviceToDelete.getId() != destinationDeviceId) { // no point in persisting a device we just purged

View File

@ -71,18 +71,7 @@ public class UnlinkDeviceCommand extends EnvironmentCommand<WhisperServerConfigu
for (byte deviceId : deviceIds) {
/** see {@link org.whispersystems.textsecuregcm.controllers.DeviceController#removeDevice} */
System.out.format("Removing device %s::%d\n", aci, deviceId);
account = deps.accountsManager().update(account, a -> a.removeDevice(deviceId));
System.out.format("Removing keys for device %s::%d\n", aci, deviceId);
deps.keysManager().delete(account.getUuid(), deviceId).join();
System.out.format("Clearing additional messages for %s::%d\n", aci, deviceId);
deps.messagesManager().clear(account.getUuid(), deviceId).join();
System.out.format("Clearing presence state for %s::%d\n", aci, deviceId);
deps.clientPresenceManager().disconnectPresence(aci, deviceId);
System.out.format("Device %s::%d successfully removed\n", aci, deviceId);
deps.accountsManager().removeDevice(account, deviceId).join();
}
} finally {
commandStopListener.stop();

View File

@ -15,7 +15,6 @@ import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
@ -801,13 +800,16 @@ class DeviceControllerTest {
}
@Test
void deviceRemovalClearsMessagesAndKeys() {
void removeDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
final Response response = resources
.getJerseyTest()
.target("/v1/devices/" + deviceId)
@ -819,10 +821,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
verify(keysManager).delete(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT, deviceId);
}
@Test
@ -840,10 +839,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(403);
verify(messagesManager, never()).clear(any(), anyByte());
verify(accountsManager, never()).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT, never()).removeDevice(anyByte());
verify(keysManager, never()).delete(any(), anyByte());
verify(accountsManager, never()).removeDevice(any(), anyByte());
}
}

View File

@ -149,13 +149,14 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
void removeDevice() {
final byte deviceId = 17;
when(accountsManager.removeDevice(any(), anyByte()))
.thenReturn(CompletableFuture.completedFuture(authenticatedAccount));
final RemoveDeviceResponse ignored = authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(deviceId)
.build());
verify(messagesManager, times(2)).clear(AUTHENTICATED_ACI, deviceId);
verify(keysManager).delete(AUTHENTICATED_ACI, deviceId);
verify(authenticatedAccount).removeDevice(deviceId);
verify(accountsManager).removeDevice(authenticatedAccount, deviceId);
}
@Test
@ -163,6 +164,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
assertStatusException(Status.INVALID_ARGUMENT, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(1)
.build()));
verify(accountsManager, never()).removeDevice(any(), anyByte());
}
@Test
@ -171,6 +174,8 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest<DevicesGrpcService, Devi
assertStatusException(Status.PERMISSION_DENIED, () -> authenticatedServiceStub().removeDevice(RemoveDeviceRequest.newBuilder()
.setId(17)
.build()));
verify(accountsManager, never()).removeDevice(any(), anyByte());
}
@ParameterizedTest

View File

@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@ -930,6 +931,47 @@ class AccountsManagerTest {
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testRemoveDevice() {
final Device primaryDevice = new Device();
primaryDevice.setId(Device.PRIMARY_ID);
final Device linkedDevice = new Device();
linkedDevice.setId((byte) (Device.PRIMARY_ID + 1));
Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice, linkedDevice));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
assertTrue(account.getDevice(linkedDevice.getId()).isPresent());
account = accountsManager.removeDevice(account, linkedDevice.getId()).join();
assertFalse(account.getDevice(linkedDevice.getId()).isPresent());
verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId());
verify(keysManager).delete(account.getUuid(), linkedDevice.getId());
verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId());
}
@Test
void testRemovePrimaryDevice() {
final Device primaryDevice = new Device();
primaryDevice.setId(Device.PRIMARY_ID);
final Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice));
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
assertThrows(IllegalArgumentException.class, () -> accountsManager.removeDevice(account, Device.PRIMARY_ID));
assertTrue(account.getPrimaryDevice().isPresent());
verify(messagesManager, never()).clear(any(), anyByte());
verify(keysManager, never()).delete(any(), anyByte());
verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte());
}
@Test
void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any(), any())).thenReturn(true);

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
@ -17,8 +16,6 @@ import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.util.MockUtils.exactly;
@ -49,8 +46,6 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
@ -93,6 +88,9 @@ class MessagePersisterTest {
destinationAccount = mock(Account.class);;
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(destinationAccount));
when(accountsManager.removeDevice(any(), anyByte()))
.thenAnswer(invocation -> CompletableFuture.completedFuture(invocation.getArgument(0)));
when(destinationAccount.getUuid()).thenReturn(DESTINATION_ACCOUNT_UUID);
when(destinationAccount.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());