Add regression test for set profile badges calculation

This commit is contained in:
Chris Eager 2025-04-11 15:34:12 -05:00 committed by Chris Eager
parent 7cac6f6f72
commit 0585f862cb
3 changed files with 164 additions and 2 deletions

View File

@ -1140,6 +1140,58 @@ class ProfileControllerTest {
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false), new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false)); new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false));
} }
}
@Test
void testSetProfileBadgeAfterUpdateTries() throws Exception {
final ProfileKeyCommitment commitment = new ProfileKey(new byte[32]).getCommitment(
new ServiceId.Aci(AuthHelper.VALID_UUID));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
final String version = versionHex("anotherversion");
clearInvocations(AuthHelper.VALID_ACCOUNT_TWO);
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true)
))
.thenReturn(List.of(
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), true)
));
try (final Response response = resources.getJerseyTest()
.target("/v1/profile/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.put(Entity.entity(new CreateProfileRequest(commitment, version, name, emoji, about, null, false, false,
Optional.of(List.of("TEST1")), null), MediaType.APPLICATION_JSON_TYPE))) {
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.hasEntity()).isFalse();
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(AuthHelper.VALID_ACCOUNT_TWO, times(accountsManagerUpdateRetryCount)).getBadges();
final List<AccountBadge> badges = badgeCaptor.getValue();
assertThat(badges).isNotNull().hasSize(4).containsOnly(
new AccountBadge("TEST1", Instant.ofEpochSecond(42 + 86400), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST3", Instant.ofEpochSecond(42 + 86400), false),
new AccountBadge("TEST4", Instant.ofEpochSecond(43 + 86400), false));
}
} }
@ParameterizedTest @ParameterizedTest

View File

@ -15,7 +15,9 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -30,6 +32,7 @@ import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -93,11 +96,13 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceCapability;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.VersionedProfile; import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper; import org.whispersystems.textsecuregcm.tests.util.ProfileTestHelper;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
@ -144,6 +149,8 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
@Mock @Mock
private ServerZkProfileOperations serverZkProfileOperations; private ServerZkProfileOperations serverZkProfileOperations;
private Clock clock;
@Override @Override
protected ProfileGrpcService createServiceBeforeEachTest() { protected ProfileGrpcService createServiceBeforeEachTest() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class); @SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@ -203,8 +210,10 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null)); when(asyncS3client.deleteObject(any(DeleteObjectRequest.class))).thenReturn(CompletableFuture.completedFuture(null));
clock = Clock.fixed(Instant.ofEpochSecond(42), ZoneId.of("Etc/UTC"));
return new ProfileGrpcService( return new ProfileGrpcService(
Clock.systemUTC(), clock,
accountsManager, accountsManager,
profilesManager, profilesManager,
dynamicConfigurationManager, dynamicConfigurationManager,
@ -392,6 +401,42 @@ public class ProfileGrpcServiceTest extends SimpleBaseGrpcTest<ProfileGrpcServic
} }
} }
@Test
void setProfileBadges() throws InvalidInputException {
final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(AUTHENTICATED_ACI)).serialize();
final SetProfileRequest request = SetProfileRequest.newBuilder()
.setVersion(VERSION)
.setName(ByteString.copyFrom(VALID_NAME))
.setAvatarChange(AvatarChange.AVATAR_CHANGE_UNCHANGED)
.addAllBadgeIds(List.of("TEST3"))
.setCommitment(ByteString.copyFrom(commitment))
.build();
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
// set up two invocations -- one for each AccountsManager#update try
when(account.getBadges())
.thenReturn(List.of(new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)))
.thenReturn(List.of(new AccountBadge("TEST2", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST3", Instant.ofEpochSecond(41), false)));
//noinspection ResultOfMethodCallIgnored
authenticatedServiceStub().setProfile(request);
//noinspection unchecked
final ArgumentCaptor<List<AccountBadge>> badgeCaptor = ArgumentCaptor.forClass(List.class);
verify(account, times(2)).setBadges(refEq(clock), badgeCaptor.capture());
// since the stubbing of getBadges() is brittle, we need to verify the number of invocations, to protect against upstream changes
verify(account, times(accountsManagerUpdateRetryCount)).getBadges();
assertEquals(List.of(
new AccountBadge("TEST3", Instant.ofEpochSecond(41), true),
new AccountBadge("TEST2", Instant.ofEpochSecond(41), false)),
badgeCaptor.getValue());
}
@ParameterizedTest @ParameterizedTest
@EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"}) @EnumSource(value = org.signal.chat.common.IdentityType.class, names = {"IDENTITY_TYPE_ACI", "IDENTITY_TYPE_PNI"})
void getUnversionedProfile(final IdentityType identityType) { void getUnversionedProfile(final IdentityType identityType) {

View File

@ -31,8 +31,8 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AccountsHelper { public class AccountsHelper {
@ -62,6 +62,71 @@ public class AccountsHelper {
setupMockUpdate(mockAccountsManager, false); setupMockUpdate(mockAccountsManager, false);
} }
/**
* Sets up stubbing for:
* <ul>
* <li>{@link AccountsManager#update(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateAsync(Account, Consumer)}</li>
* <li>{@link AccountsManager#updateDevice(Account, byte, Consumer)}</li>
* <li>{@link AccountsManager#updateDeviceAsync(Account, byte, Consumer)}</li>
* </ul>
*
* with multiple calls to the {@link Consumer<Account>}. This simulates retries from {@link org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException}.
* Callers will typically set up stubbing for relevant {@link Account} methods with multiple {@link org.mockito.stubbing.OngoingStubbing#thenReturn(Object)}
* calls:
* <pre>
* // example stubbing
* when(account.getNextDeviceId())
* .thenReturn(2)
* .thenReturn(3);
* </pre>
*/
@SuppressWarnings("unchecked")
public static void setupMockUpdateWithRetries(final AccountsManager mockAccountsManager, final int retryCount) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateAsync(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
for (int i = 0; i < retryCount; i++) {
answer.getArgument(1, Consumer.class).accept(account);
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return copyAndMarkStale(account);
});
when(mockAccountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
for (int i = 0; i < retryCount; i++) {
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
}
return CompletableFuture.completedFuture(copyAndMarkStale(account));
});
}
@SuppressWarnings("unchecked")
private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) { private static void setupMockUpdate(final AccountsManager mockAccountsManager, final boolean markStale) {
when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> { when(mockAccountsManager.update(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class); final Account account = answer.getArgument(0, Account.class);