Forbid linked devices from setting backup-ids

This commit is contained in:
Ravi Khadiwala 2025-06-17 16:16:07 -05:00 committed by ravi-signal
parent 5de848bf38
commit 9dfe51eac4
6 changed files with 71 additions and 17 deletions

View File

@ -34,6 +34,7 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -85,6 +86,7 @@ public class BackupAuthManager {
* Store credential requests containing blinded backup-ids for future use. * Store credential requests containing blinded backup-ids for future use.
* *
* @param account The account using the backup-id * @param account The account using the backup-id
* @param device The device setting the account backup-id
* @param messagesBackupCredentialRequest A request containing the blinded backup-id the client will use to upload * @param messagesBackupCredentialRequest A request containing the blinded backup-id the client will use to upload
* message backups * message backups
* @param mediaBackupCredentialRequest A request containing the blinded backup-id the client will use to upload * @param mediaBackupCredentialRequest A request containing the blinded backup-id the client will use to upload
@ -92,12 +94,17 @@ public class BackupAuthManager {
* @return A future that completes when the credentialRequest has been stored * @return A future that completes when the credentialRequest has been stored
* @throws RateLimitExceededException If too many backup-ids have been committed * @throws RateLimitExceededException If too many backup-ids have been committed
*/ */
public CompletableFuture<Void> commitBackupId(final Account account, public CompletableFuture<Void> commitBackupId(
final Account account,
final Device device,
final BackupAuthCredentialRequest messagesBackupCredentialRequest, final BackupAuthCredentialRequest messagesBackupCredentialRequest,
final BackupAuthCredentialRequest mediaBackupCredentialRequest) { final BackupAuthCredentialRequest mediaBackupCredentialRequest) {
if (configuredBackupLevel(account).isEmpty()) { if (configuredBackupLevel(account).isEmpty()) {
throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException();
} }
if (!device.isPrimary()) {
throw Status.PERMISSION_DENIED.withDescription("Only primary device can set backup-id").asRuntimeException();
}
final byte[] serializedMessageCredentialRequest = messagesBackupCredentialRequest.serialize(); final byte[] serializedMessageCredentialRequest = messagesBackupCredentialRequest.serialize();
final byte[] serializedMediaCredentialRequest = mediaBackupCredentialRequest.serialize(); final byte[] serializedMediaCredentialRequest = mediaBackupCredentialRequest.serialize();

View File

@ -135,13 +135,14 @@ public class ArchiveController {
""") """)
@ApiResponse(responseCode = "204", description = "The backup-id was set") @ApiResponse(responseCode = "204", description = "The backup-id was set")
@ApiResponse(responseCode = "400", description = "The provided backup auth credential request was invalid") @ApiResponse(responseCode = "400", description = "The provided backup auth credential request was invalid")
@ApiResponse(responseCode = "403", description = "The device did not have permission to set the backup-id. Only the primary device can set the backup-id for an account")
@ApiResponse(responseCode = "429", description = "Rate limited. Too many attempts to change the backup-id have been made") @ApiResponse(responseCode = "429", description = "Rate limited. Too many attempts to change the backup-id have been made")
public CompletionStage<Response> setBackupId( public CompletionStage<Response> setBackupId(
@Mutable @Auth final AuthenticatedDevice account, @Mutable @Auth final AuthenticatedDevice account,
@Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException { @Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException {
return this.backupAuthManager return this.backupAuthManager
.commitBackupId(account.getAccount(), setBackupIdRequest.messagesBackupAuthCredentialRequest, .commitBackupId(account.getAccount(), account.getAuthenticatedDevice(),
setBackupIdRequest.messagesBackupAuthCredentialRequest,
setBackupIdRequest.mediaBackupAuthCredentialRequest) setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }

View File

@ -32,6 +32,7 @@ import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
@ -60,9 +61,15 @@ public class BackupsGrpcService extends ReactorBackupsGrpc.BackupsImplBase {
BackupAuthCredentialRequest::new, BackupAuthCredentialRequest::new,
request.getMediaBackupAuthCredentialRequest().toByteArray()); request.getMediaBackupAuthCredentialRequest().toByteArray());
final AuthenticatedDevice authenticatedDevice = AuthenticationUtil.requireAuthenticatedDevice();
return authenticatedAccount() return authenticatedAccount()
.flatMap(account -> Mono.fromFuture( .flatMap(account -> {
backupAuthManager.commitBackupId(account, messagesCredentialRequest, mediaCredentialRequest))) final Device device = account
.getDevice(authenticatedDevice.deviceId())
.orElseThrow(Status.UNAUTHENTICATED::asRuntimeException);
return Mono.fromFuture(
backupAuthManager.commitBackupId(account, device, messagesCredentialRequest, mediaCredentialRequest));
})
.thenReturn(SetBackupIdResponse.getDefaultInstance()); .thenReturn(SetBackupIdResponse.getDefaultInstance());
} }

View File

@ -61,6 +61,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
@ -119,7 +120,7 @@ public class BackupAuthManagerTest {
final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci); final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci);
final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci); final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci);
authManager.commitBackupId(account, messagesCredentialRequest, mediaCredentialRequest).join(); authManager.commitBackupId(account, primaryDevice(), messagesCredentialRequest, mediaCredentialRequest).join();
verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), mediaCredentialRequest.serialize()); verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), mediaCredentialRequest.serialize());
} }
@ -135,6 +136,7 @@ public class BackupAuthManagerTest {
final ThrowableAssert.ThrowingCallable commit = () -> final ThrowableAssert.ThrowingCallable commit = () ->
authManager.commitBackupId(account, authManager.commitBackupId(account,
primaryDevice(),
backupAuthTestUtil.getRequest(messagesBackupKey, aci), backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join(); backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join();
if (backupLevel == null) { if (backupLevel == null) {
@ -147,6 +149,24 @@ public class BackupAuthManagerTest {
} }
} }
@Test
void commitRequiresPrimary() {
final BackupAuthManager authManager = create(BackupLevel.FREE);
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account));
final ThrowableAssert.ThrowingCallable commit = () ->
authManager.commitBackupId(account,
linkedDevice(),
backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join();
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(commit)
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.PERMISSION_DENIED);
}
@CartesianTest @CartesianTest
void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel, void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel,
@CartesianTest.Enum final BackupCredentialType credentialType) { @CartesianTest.Enum final BackupCredentialType credentialType) {
@ -504,7 +524,7 @@ public class BackupAuthManagerTest {
: storedMediaCredential; : storedMediaCredential;
final boolean expectRateLimit = (changeMedia || changeMessage) && rateLimitBackupId; final boolean expectRateLimit = (changeMedia || changeMessage) && rateLimitBackupId;
final CompletableFuture<Void> future = authManager.commitBackupId(account, newMessagesCredential, newMediaCredential); final CompletableFuture<Void> future = authManager.commitBackupId(account, primaryDevice(), newMessagesCredential, newMediaCredential);
if (expectRateLimit) { if (expectRateLimit) {
CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future); CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future);
} else { } else {
@ -538,7 +558,7 @@ public class BackupAuthManagerTest {
// We should get rate limited iff we are out of paid media changes and we changed the media backup-id // We should get rate limited iff we are out of paid media changes and we changed the media backup-id
final boolean expectRateLimit = changeMedia && paid && rateLimitPaidMedia; final boolean expectRateLimit = changeMedia && paid && rateLimitPaidMedia;
final CompletableFuture<Void> future = authManager.commitBackupId(account, newMessagesCredential, newMediaCredential); final CompletableFuture<Void> future = authManager.commitBackupId(account, primaryDevice(), newMessagesCredential, newMediaCredential);
if (expectRateLimit) { if (expectRateLimit) {
CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future); CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, future);
} else { } else {
@ -562,6 +582,17 @@ public class BackupAuthManagerTest {
return account; return account;
} }
private Device primaryDevice() {
final Device device = mock(Device.class);
when(device.isPrimary()).thenReturn(true);
return device;
}
private Device linkedDevice() {
final Device device = mock(Device.class);
when(device.isPrimary()).thenReturn(false);
return device;
}
private static String experimentName(@Nullable BackupLevel backupLevel) { private static String experimentName(@Nullable BackupLevel backupLevel) {
return switch (backupLevel) { return switch (backupLevel) {

View File

@ -157,7 +157,7 @@ public class ArchiveControllerTest {
@Test @Test
public void setBackupId() { public void setBackupId() {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupAuthManager.commitBackupId(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/backupid") .target("v1/archives/backupid")
@ -170,7 +170,7 @@ public class ArchiveControllerTest {
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
verify(backupAuthManager).commitBackupId(AuthHelper.VALID_ACCOUNT, verify(backupAuthManager).commitBackupId(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE,
backupAuthTestUtil.getRequest(messagesBackupKey, aci), backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)); backupAuthTestUtil.getRequest(mediaBackupKey, aci));
} }
@ -275,9 +275,9 @@ public class ArchiveControllerTest {
@MethodSource @MethodSource
public void setBackupIdException(final Exception ex, final boolean sync, final int expectedStatus) { public void setBackupIdException(final Exception ex, final boolean sync, final int expectedStatus) {
if (sync) { if (sync) {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenThrow(ex); when(backupAuthManager.commitBackupId(any(), any(), any(), any())).thenThrow(ex);
} else { } else {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.failedFuture(ex)); when(backupAuthManager.commitBackupId(any(), any(), any(), any())).thenReturn(CompletableFuture.failedFuture(ex));
} }
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/backupid") .target("v1/archives/backupid")

View File

@ -57,6 +57,7 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics; import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -69,7 +70,8 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
backupAuthTestUtil.getRequest(mediaBackupKey, AUTHENTICATED_ACI); backupAuthTestUtil.getRequest(mediaBackupKey, AUTHENTICATED_ACI);
final BackupAuthCredentialRequest messagesAuthCredRequest = final BackupAuthCredentialRequest messagesAuthCredRequest =
backupAuthTestUtil.getRequest(messagesBackupKey, AUTHENTICATED_ACI); backupAuthTestUtil.getRequest(messagesBackupKey, AUTHENTICATED_ACI);
private final Account account = mock(Account.class); private Account account;
private Device device;
@Mock @Mock
private BackupAuthManager backupAuthManager; private BackupAuthManager backupAuthManager;
@ -83,14 +85,19 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
@BeforeEach @BeforeEach
void setup() { void setup() {
account = mock(Account.class);
device = mock(Device.class);
when(device.isPrimary()).thenReturn(true);
when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI)) when(accountsManager.getByAccountIdentifierAsync(AUTHENTICATED_ACI))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(account.getDevice(AUTHENTICATED_DEVICE_ID)).thenReturn(Optional.of(device));
} }
@Test @Test
void setBackupId() { void setBackupId() {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupAuthManager.commitBackupId(any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
authenticatedServiceStub().setBackupId( authenticatedServiceStub().setBackupId(
SetBackupIdRequest.newBuilder() SetBackupIdRequest.newBuilder()
@ -98,7 +105,7 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
.setMessagesBackupAuthCredentialRequest(ByteString.copyFrom(messagesAuthCredRequest.serialize())) .setMessagesBackupAuthCredentialRequest(ByteString.copyFrom(messagesAuthCredRequest.serialize()))
.build()); .build());
verify(backupAuthManager).commitBackupId(account, messagesAuthCredRequest, mediaAuthCredRequest); verify(backupAuthManager).commitBackupId(account, device, messagesAuthCredRequest, mediaAuthCredRequest);
} }
@Test @Test
@ -147,9 +154,10 @@ class BackupsGrpcServiceTest extends SimpleBaseGrpcTest<BackupsGrpcService, Back
@MethodSource @MethodSource
void setBackupIdException(final Exception ex, final boolean sync, final Status expected) { void setBackupIdException(final Exception ex, final boolean sync, final Status expected) {
if (sync) { if (sync) {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenThrow(ex); when(backupAuthManager.commitBackupId(any(), any(), any(), any())).thenThrow(ex);
} else { } else {
when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.failedFuture(ex)); when(backupAuthManager.commitBackupId(any(), any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(ex));
} }
GrpcTestUtils.assertStatusException( GrpcTestUtils.assertStatusException(