Add API endpoints for waiting for account restoration requests

This commit is contained in:
Jon Chambers 2024-10-24 12:25:40 -04:00 committed by GitHub
parent 5c4cafcb6f
commit 324913d2da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 298 additions and 12 deletions

View File

@ -29,6 +29,7 @@ import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
import javax.ws.rs.Consumes;
@ -56,6 +57,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.DeviceInfoList;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.LinkDeviceResponse;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator;
@ -64,6 +66,7 @@ import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@ -437,6 +440,70 @@ public class DeviceController {
return isDowngrade;
}
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@Path("/restore_account/{token}")
@Operation(
summary = "Signals that a new device is requesting restoration of account data by some method",
description = """
Signals that a new device is requesting restoration of account data by some method. Devices waiting via the
"wait for 'restore account' request" endpoint will be notified that the request has been issued.
""")
@ApiResponse(responseCode = "204", description = "Success")
@ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@RateLimitedByIp(RateLimiters.For.RECORD_DEVICE_TRANSFER_REQUEST)
public CompletionStage<Void> recordRestoreAccountRequest(
@PathParam("token")
@NotBlank
@Size(max = 64)
@Schema(description = "A randomly-generated token identifying the request for device-to-device transfer.",
requiredMode = Schema.RequiredMode.REQUIRED,
maximum = "64") final String token,
@Valid
final RestoreAccountRequest restoreAccountRequest) {
return accounts.recordRestoreAccountRequest(token, restoreAccountRequest);
}
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/restore_account/{token}")
@Operation(summary = "Wait for 'restore account' request")
@ApiResponse(responseCode = "200", description = "A 'restore account' request was received for the given token",
content = @Content(schema = @Schema(implementation = RestoreAccountRequest.class)))
@ApiResponse(responseCode = "204", description = "No 'restore account' request for the given token was received before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given token or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@RateLimitedByIp(RateLimiters.For.WAIT_FOR_DEVICE_TRANSFER_REQUEST)
public CompletionStage<Response> waitForDeviceTransferRequest(
@PathParam("token")
@NotBlank
@Size(max = 64)
@Schema(description = "A randomly-generated token identifying the request for device-to-device transfer.",
requiredMode = Schema.RequiredMode.REQUIRED,
maximum = "64") final String token,
@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If a transfer archive for the authenticated
device is not available within the given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds) {
return accounts.waitForRestoreAccountRequest(token, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeRequestReceived -> maybeRequestReceived
.map(restoreAccountRequest -> Response.status(Response.Status.OK).entity(restoreAccountRequest).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()));
}
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)

View File

@ -0,0 +1,32 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import io.swagger.v3.oas.annotations.media.Schema;
import javax.validation.constraints.NotNull;
@Schema(description = """
Represents a request from a new device to restore account data by some method.
""")
public record RestoreAccountRequest(
@NotNull
@Schema(description = "The method by which the new device has requested account data restoration")
Method method) {
public enum Method {
@Schema(description = "Restore account data from a remote message history backup")
REMOTE_BACKUP,
@Schema(description = "Restore account data from a local backup archive")
LOCAL_BACKUP,
@Schema(description = "Restore account data via direct device-to-device transfer")
DEVICE_TRANSFER,
@Schema(description = "Do not restore account data")
DECLINE,
}
}

View File

@ -53,6 +53,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))),
;
private final String id;

View File

@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
@ -142,6 +143,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Map<TimestampedDeviceIdentifier, CompletableFuture<Optional<RemoteAttachment>>> waitForTransferArchiveFuturesByDeviceIdentifier =
new ConcurrentHashMap<>();
private final Map<String, CompletableFuture<Optional<RestoreAccountRequest>>> waitForRestoreAccountRequestFuturesByToken =
new ConcurrentHashMap<>();
private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength();
private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1);
@ -152,6 +156,10 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::";
private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*";
private static final Duration RESTORE_ACCOUNT_REQUEST_TTL = Duration.ofHours(1);
private static final String RESTORE_ACCOUNT_REQUEST_PREFIX = "restore_account::";
private static final String RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN = "__keyspace@0__:" + RESTORE_ACCOUNT_REQUEST_PREFIX + "*";
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@ -238,7 +246,8 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
public void start() {
pubSubConnection.usePubSubConnection(connection -> {
connection.addListener(this);
connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN);
connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN,
RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN);
});
}
@ -1496,6 +1505,44 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
":" + destinationDeviceCreationTimestamp.toEpochMilli();
}
public CompletableFuture<Optional<RestoreAccountRequest>> waitForRestoreAccountRequest(final String token, final Duration timeout) {
return waitForPubSubKey(waitForRestoreAccountRequestFuturesByToken,
token,
getRestoreAccountRequestKey(token),
timeout,
this::handleRestoreAccountRequest);
}
public CompletableFuture<Void> recordRestoreAccountRequest(final String token, final RestoreAccountRequest restoreAccountRequest) {
final String key = getRestoreAccountRequestKey(token);
final String requestJson;
try {
requestJson = SystemMapper.jsonMapper().writeValueAsString(restoreAccountRequest);
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
return pubSubRedisClient.withConnection(connection ->
connection.async().set(key, requestJson, SetArgs.Builder.ex(RESTORE_ACCOUNT_REQUEST_TTL)))
.thenRun(Util.NOOP)
.toCompletableFuture();
}
private void handleRestoreAccountRequest(final CompletableFuture<Optional<RestoreAccountRequest>> future, final String transferRequestJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(transferRequestJson, RestoreAccountRequest.class)));
} catch (final JsonProcessingException e) {
logger.error("Could not parse device transfer request JSON", e);
future.completeExceptionally(e);
}
}
private static String getRestoreAccountRequestKey(final String token) {
return RESTORE_ACCOUNT_REQUEST_PREFIX + token;
}
private <K, T> CompletableFuture<Optional<T>> waitForPubSubKey(final Map<K, CompletableFuture<Optional<T>>> futureMap,
final K mapKey,
final String redisKey,
@ -1564,6 +1611,14 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} catch (final IllegalArgumentException e) {
logger.error("Could not parse timestamped device identifier", e);
}
} else if (RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.equalsIgnoreCase(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String token = channel.substring(RESTORE_ACCOUNT_REQUEST_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForRestoreAccountRequestFuturesByToken.remove(token))
.ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(
getRestoreAccountRequestKey(token)))
.thenAccept(requestJson -> handleRestoreAccountRequest(future, requestJson)));
}
}

View File

@ -51,6 +51,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
@ -61,6 +62,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.LinkDeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@ -1019,7 +1021,7 @@ class DeviceControllerTest {
}
@ParameterizedTest
@MethodSource
@ValueSource(ints = {0, -1, 3601})
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
@ -1034,10 +1036,6 @@ class DeviceControllerTest {
}
}
private static List<Integer> waitForLinkedDeviceBadTimeout() {
return List.of(0, -1, 3601);
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
@ -1194,7 +1192,7 @@ class DeviceControllerTest {
}
@ParameterizedTest
@MethodSource
@ValueSource(ints = {0, -1, 3601})
void waitForTransferArchiveBadTimeout(final int timeoutSeconds) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
@ -1207,10 +1205,6 @@ class DeviceControllerTest {
}
}
private static List<Integer> waitForTransferArchiveBadTimeout() {
return List.of(0, -1, 3601);
}
@Test
void waitForTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(anyString()))
@ -1225,4 +1219,101 @@ class DeviceControllerTest {
assertEquals(429, response.getStatus());
}
}
@Test
void recordRestoreAccountRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
when(accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(204, response.getStatus());
}
}
@Test
void recordRestoreAccountRequestBadToken() {
final String token = RandomStringUtils.randomAlphanumeric(128);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(400, response.getStatus());
}
}
@Test
void recordRestoreAccountRequestInvalidRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest = new RestoreAccountRequest(null);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.put(Entity.json(restoreAccountRequest))) {
assertEquals(422, response.getStatus());
}
}
@Test
void waitForDeviceTransferRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.LOCAL_BACKUP);
when(accountsManager.waitForRestoreAccountRequest(eq(token), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(restoreAccountRequest)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.get()) {
assertEquals(200, response.getStatus());
assertEquals(restoreAccountRequest, response.readEntity(RestoreAccountRequest.class));
}
}
@Test
void waitForDeviceTransferRequestNoRequestIssued() {
final String token = RandomStringUtils.randomAlphanumeric(16);
when(accountsManager.waitForRestoreAccountRequest(eq(token), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.request()
.get()) {
assertEquals(204, response.getStatus());
}
}
@ParameterizedTest
@ValueSource(ints = {0, -1, 3601})
void waitForDeviceTransferRequestBadTimeout(final int timeoutSeconds) {
final String token = RandomStringUtils.randomAlphanumeric(16);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/restore_account/" + token)
.queryParam("timeout", timeoutSeconds)
.request()
.get()) {
assertEquals(400, response.getStatus());
}
}
}

View File

@ -5,11 +5,13 @@
package org.whispersystems.textsecuregcm.storage;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.RestoreAccountRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
@ -34,7 +36,7 @@ import static org.mockito.Mockito.when;
// ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be
// preempted by the timeout check
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class AccountsManagerTransferArchiveIntegrationTest {
public class AccountsManagerDeviceTransferIntegrationTest {
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
@ -144,4 +146,41 @@ public class AccountsManagerTransferArchiveIntegrationTest {
assertEquals(Optional.empty(),
accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join());
}
@Test
void waitForRestoreAccountRequest() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER);
final CompletableFuture<Optional<RestoreAccountRequest>> displacedFuture =
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5));
final CompletableFuture<Optional<RestoreAccountRequest>> activeFuture =
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join());
accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join();
assertEquals(Optional.of(restoreAccountRequest), activeFuture.join());
}
@Test
void waitForRestoreAccountRequestAlreadyRequested() {
final String token = RandomStringUtils.randomAlphanumeric(16);
final RestoreAccountRequest restoreAccountRequest =
new RestoreAccountRequest(RestoreAccountRequest.Method.DEVICE_TRANSFER);
accountsManager.recordRestoreAccountRequest(token, restoreAccountRequest).join();
assertEquals(Optional.of(restoreAccountRequest),
accountsManager.waitForRestoreAccountRequest(token, Duration.ofSeconds(5)).join());
}
@Test
void waitForRestoreAccountRequestTimeout() {
assertEquals(Optional.empty(),
accountsManager.waitForRestoreAccountRequest(RandomStringUtils.randomAlphanumeric(16), Duration.ofMillis(1)).join());
}
}