diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index f2d7abddb..bec6a9505 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -29,6 +29,7 @@ import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.Max; import javax.validation.constraints.Min; +import javax.validation.constraints.NotBlank; import javax.validation.constraints.NotNull; import javax.validation.constraints.Size; import javax.ws.rs.Consumes; @@ -56,6 +57,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfoList; +import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; import org.whispersystems.textsecuregcm.entities.LinkDeviceResponse; import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; @@ -64,6 +66,7 @@ import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest; import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest; import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -437,6 +440,70 @@ public class DeviceController { return isDowngrade; } + @PUT + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + @Path("/restore_account/{token}") + @Operation( + summary = "Signals that a new device is requesting restoration of account data by some method", + description = """ + Signals that a new device is requesting restoration of account data by some method. Devices waiting via the + "wait for 'restore account' request" endpoint will be notified that the request has been issued. + """) + @ApiResponse(responseCode = "204", description = "Success") + @ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid") + @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") + @RateLimitedByIp(RateLimiters.For.RECORD_DEVICE_TRANSFER_REQUEST) + public CompletionStage recordRestoreAccountRequest( + @PathParam("token") + @NotBlank + @Size(max = 64) + @Schema(description = "A randomly-generated token identifying the request for device-to-device transfer.", + requiredMode = Schema.RequiredMode.REQUIRED, + maximum = "64") final String token, + + @Valid + final RestoreAccountRequest restoreAccountRequest) { + + return accounts.recordRestoreAccountRequest(token, restoreAccountRequest); + } + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/restore_account/{token}") + @Operation(summary = "Wait for 'restore account' request") + @ApiResponse(responseCode = "200", description = "A 'restore account' request was received for the given token", + content = @Content(schema = @Schema(implementation = RestoreAccountRequest.class))) + @ApiResponse(responseCode = "204", description = "No 'restore account' request for the given token was received before the call completed; clients may repeat the call to continue waiting") + @ApiResponse(responseCode = "400", description = "The given token or timeout was invalid") + @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") + @RateLimitedByIp(RateLimiters.For.WAIT_FOR_DEVICE_TRANSFER_REQUEST) + public CompletionStage waitForDeviceTransferRequest( + @PathParam("token") + @NotBlank + @Size(max = 64) + @Schema(description = "A randomly-generated token identifying the request for device-to-device transfer.", + requiredMode = Schema.RequiredMode.REQUIRED, + maximum = "64") final String token, + + @QueryParam("timeout") + @DefaultValue("30") + @Min(1) + @Max(3600) + @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, + minimum = "1", + maximum = "3600", + description = """ + The amount of time (in seconds) to wait for a response. If a transfer archive for the authenticated + device is not available within the given amount of time, this endpoint will return a status of HTTP/204. + """) final int timeoutSeconds) { + + return accounts.waitForRestoreAccountRequest(token, Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeRequestReceived -> maybeRequestReceived + .map(restoreAccountRequest -> Response.status(Response.Status.OK).entity(restoreAccountRequest).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())); + } + @PUT @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RestoreAccountRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RestoreAccountRequest.java new file mode 100644 index 000000000..89ecf5a4e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RestoreAccountRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.NotNull; + +@Schema(description = """ + Represents a request from a new device to restore account data by some method. + """) +public record RestoreAccountRequest( + @NotNull + @Schema(description = "The method by which the new device has requested account data restoration") + Method method) { + + public enum Method { + @Schema(description = "Restore account data from a remote message history backup") + REMOTE_BACKUP, + + @Schema(description = "Restore account data from a local backup archive") + LOCAL_BACKUP, + + @Schema(description = "Restore account data via direct device-to-device transfer") + DEVICE_TRANSFER, + + @Schema(description = "Do not restore account data") + DECLINE, + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index aa07085e6..1d210a253 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -53,6 +53,8 @@ public class RateLimiters extends BaseRateLimiters { WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))), WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), + RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))), + WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))), ; private final String id; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 344dd6826..0b30f728a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceInfo; +import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.RemoteAttachment; @@ -142,6 +143,9 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final Map>> waitForTransferArchiveFuturesByDeviceIdentifier = new ConcurrentHashMap<>(); + private final Map>> waitForRestoreAccountRequestFuturesByToken = + new ConcurrentHashMap<>(); + private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength(); private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1); @@ -152,6 +156,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::"; private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*"; + private static final Duration RESTORE_ACCOUNT_REQUEST_TTL = Duration.ofHours(1); + private static final String RESTORE_ACCOUNT_REQUEST_PREFIX = "restore_account::"; + private static final String RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN = "__keyspace@0__:" + RESTORE_ACCOUNT_REQUEST_PREFIX + "*"; + private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() .writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); @@ -238,7 +246,8 @@ public class AccountsManager extends RedisPubSubAdapter implemen public void start() { pubSubConnection.usePubSubConnection(connection -> { connection.addListener(this); - connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN); + connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN, + RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN); }); } @@ -1496,6 +1505,44 @@ public class AccountsManager extends RedisPubSubAdapter implemen ":" + destinationDeviceCreationTimestamp.toEpochMilli(); } + public CompletableFuture> waitForRestoreAccountRequest(final String token, final Duration timeout) { + return waitForPubSubKey(waitForRestoreAccountRequestFuturesByToken, + token, + getRestoreAccountRequestKey(token), + timeout, + this::handleRestoreAccountRequest); + } + + public CompletableFuture recordRestoreAccountRequest(final String token, final RestoreAccountRequest restoreAccountRequest) { + final String key = getRestoreAccountRequestKey(token); + + final String requestJson; + + try { + requestJson = SystemMapper.jsonMapper().writeValueAsString(restoreAccountRequest); + } catch (final JsonProcessingException e) { + throw new UncheckedIOException(e); + } + + return pubSubRedisClient.withConnection(connection -> + connection.async().set(key, requestJson, SetArgs.Builder.ex(RESTORE_ACCOUNT_REQUEST_TTL))) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } + + private void handleRestoreAccountRequest(final CompletableFuture> future, final String transferRequestJson) { + try { + future.complete(Optional.of(SystemMapper.jsonMapper().readValue(transferRequestJson, RestoreAccountRequest.class))); + } catch (final JsonProcessingException e) { + logger.error("Could not parse device transfer request JSON", e); + future.completeExceptionally(e); + } + } + + private static String getRestoreAccountRequestKey(final String token) { + return RESTORE_ACCOUNT_REQUEST_PREFIX + token; + } + private CompletableFuture> waitForPubSubKey(final Map>> futureMap, final K mapKey, final String redisKey, @@ -1564,6 +1611,14 @@ public class AccountsManager extends RedisPubSubAdapter implemen } catch (final IllegalArgumentException e) { logger.error("Could not parse timestamped device identifier", e); } + } else if (RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.equalsIgnoreCase(pattern) && "set".equalsIgnoreCase(message)) { + // The `- 1` here compensates for the '*' in the pattern + final String token = channel.substring(RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.length() - 1); + + Optional.ofNullable(waitForRestoreAccountRequestFuturesByToken.remove(token)) + .ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get( + getRestoreAccountRequestKey(token))) + .thenAccept(requestJson -> handleRestoreAccountRequest(future, requestJson))); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index adf75bed0..01f7f9eb7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -51,6 +51,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; @@ -61,6 +62,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceInfo; +import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; import org.whispersystems.textsecuregcm.entities.LinkDeviceResponse; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; @@ -1019,7 +1021,7 @@ class DeviceControllerTest { } @ParameterizedTest - @MethodSource + @ValueSource(ints = {0, -1, 3601}) void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); @@ -1034,10 +1036,6 @@ class DeviceControllerTest { } } - private static List waitForLinkedDeviceBadTimeout() { - return List.of(0, -1, 3601); - } - @ParameterizedTest @MethodSource void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) { @@ -1194,7 +1192,7 @@ class DeviceControllerTest { } @ParameterizedTest - @MethodSource + @ValueSource(ints = {0, -1, 3601}) void waitForTransferArchiveBadTimeout(final int timeoutSeconds) { try (final Response response = resources.getJerseyTest() .target("/v1/devices/transfer_archive/") @@ -1207,10 +1205,6 @@ class DeviceControllerTest { } } - private static List waitForTransferArchiveBadTimeout() { - return List.of(0, -1, 3601); - } - @Test void waitForTransferArchiveRateLimited() { when(rateLimiter.validateAsync(anyString())) @@ -1225,4 +1219,101 @@ class DeviceControllerTest { assertEquals(429, response.getStatus()); } } + + @Test + void recordRestoreAccountRequest() { + final String token = RandomStringUtils.randomAlphanumeric(16); + final RestoreAccountRequest restoreAccountRequest = + new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP); + + when(accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest)) + .thenReturn(CompletableFuture.completedFuture(null)); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .request() + .put(Entity.json(restoreAccountRequest))) { + + assertEquals(204, response.getStatus()); + } + } + + @Test + void recordRestoreAccountRequestBadToken() { + final String token = RandomStringUtils.randomAlphanumeric(128); + final RestoreAccountRequest restoreAccountRequest = + new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .request() + .put(Entity.json(restoreAccountRequest))) { + + assertEquals(400, response.getStatus()); + } + } + + @Test + void recordRestoreAccountRequestInvalidRequest() { + final String token = RandomStringUtils.randomAlphanumeric(16); + final RestoreAccountRequest restoreAccountRequest = new RestoreAccountRequest(null); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .request() + .put(Entity.json(restoreAccountRequest))) { + + assertEquals(422, response.getStatus()); + } + } + + @Test + void waitForDeviceTransferRequest() { + final String token = RandomStringUtils.randomAlphanumeric(16); + final RestoreAccountRequest restoreAccountRequest = + new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP); + + when(accountsManager.waitForRestoreAccountRequest(eq(token), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(restoreAccountRequest))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .request() + .get()) { + + assertEquals(200, response.getStatus()); + assertEquals(restoreAccountRequest, response.readEntity(RestoreAccountRequest.class)); + } + } + + @Test + void waitForDeviceTransferRequestNoRequestIssued() { + final String token = RandomStringUtils.randomAlphanumeric(16); + + when(accountsManager.waitForRestoreAccountRequest(eq(token), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .request() + .get()) { + + assertEquals(204, response.getStatus()); + } + } + + @ParameterizedTest + @ValueSource(ints = {0, -1, 3601}) + void waitForDeviceTransferRequestBadTimeout(final int timeoutSeconds) { + final String token = RandomStringUtils.randomAlphanumeric(16); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/restore_account/" + token) + .queryParam("timeout", timeoutSeconds) + .request() + .get()) { + + assertEquals(400, response.getStatus()); + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java similarity index 76% rename from service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java index db58467ff..3dfe355cf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerDeviceTransferIntegrationTest.java @@ -5,11 +5,13 @@ package org.whispersystems.textsecuregcm.storage; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest; import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; @@ -34,7 +36,7 @@ import static org.mockito.Mockito.when; // ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be // preempted by the timeout check @Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) -public class AccountsManagerTransferArchiveIntegrationTest { +public class AccountsManagerDeviceTransferIntegrationTest { @RegisterExtension static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); @@ -144,4 +146,41 @@ public class AccountsManagerTransferArchiveIntegrationTest { assertEquals(Optional.empty(), accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join()); } + + @Test + void waitForRestoreAccountRequest() { + final String token = RandomStringUtils.randomAlphanumeric(16); + final RestoreAccountRequest restoreAccountRequest = + new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER); + + final CompletableFuture> displacedFuture = + accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5)); + + final CompletableFuture> activeFuture = + accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5)); + + assertEquals(Optional.empty(), displacedFuture.join()); + + accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join(); + + assertEquals(Optional.of(restoreAccountRequest), activeFuture.join()); + } + + @Test + void waitForRestoreAccountRequestAlreadyRequested() { + final String token = RandomStringUtils.randomAlphanumeric(16); + final RestoreAccountRequest restoreAccountRequest = + new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER); + + accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join(); + + assertEquals(Optional.of(restoreAccountRequest), + accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5)).join()); + } + + @Test + void waitForRestoreAccountRequestTimeout() { + assertEquals(Optional.empty(), + accountsManager.waitForRestoreAccountRequest(RandomStringUtils.randomAlphanumeric(16), Duration.ofMillis(1)).join()); + } }