Remove `ForkJoinPool.managedBlock` in favor of async updates

This commit is contained in:
Chris Eager 2023-12-12 11:01:24 -06:00 committed by Jon Chambers
parent 28a981f29f
commit 8d4acf0330
3 changed files with 29 additions and 35 deletions

View File

@ -10,11 +10,8 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinPool.ManagedBlocker;
import java.util.function.Function; import java.util.function.Function;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.validation.Valid; import javax.validation.Valid;
@ -36,7 +33,6 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest; import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge; import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
@ -101,43 +97,24 @@ public class DonationController {
if (badgeId == null) { if (badgeId == null) {
return CompletableFuture.completedFuture(Response.serverError().entity("server does not recognize the requested receipt level").type(MediaType.TEXT_PLAIN_TYPE).build()); return CompletableFuture.completedFuture(Response.serverError().entity("server does not recognize the requested receipt level").type(MediaType.TEXT_PLAIN_TYPE).build());
} }
final CompletionStage<Boolean> putStage = redeemedReceiptsManager.put( return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid()); receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
return putStage.thenApplyAsync(receiptMatched -> { .thenCompose(receiptMatched -> {
if (!receiptMatched) { if (!receiptMatched) {
return Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed").type(MediaType.TEXT_PLAIN_TYPE).build(); return CompletableFuture.completedFuture(
Response.status(Status.BAD_REQUEST).entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE).build());
} }
try { return accountsManager.getByAccountIdentifierAsync(auth.getAccount().getUuid())
ForkJoinPool.managedBlock(new ManagedBlocker() { .thenCompose(optionalAccount ->
boolean done = false; optionalAccount.map(account -> accountsManager.updateAsync(account, a -> {
@Override
public boolean block() {
final Optional<Account> optionalAccount = accountsManager.getByAccountIdentifier(auth.getAccount().getUuid());
optionalAccount.ifPresent(account -> {
accountsManager.update(account, a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) { if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId); a.makeBadgePrimaryIfExists(clock, badgeId);
} }
}); })).orElse(CompletableFuture.completedFuture(null)))
}); .thenApply(ignored -> Response.ok().build());
done = true;
return true;
}
@Override
public boolean isReleasable() {
return done;
}
});
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return Response.serverError().build();
}
return Response.ok().build();
}); });
}).thenCompose(Function.identity()); }).thenCompose(Function.identity());
} }

View File

@ -119,7 +119,8 @@ class DonationControllerTest {
when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration); when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration);
when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn( when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Boolean.TRUE)); CompletableFuture.completedFuture(Boolean.TRUE));
when(accountsManager.getByAccountIdentifier(eq(AuthHelper.VALID_UUID))).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifierAsync(eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true); RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()

View File

@ -20,6 +20,7 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.mockito.MockingDetails; import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing; import org.mockito.stubbing.Stubbing;
@ -69,6 +70,13 @@ public class AccountsHelper {
return markStale ? copyAndMarkStale(account) : account; return markStale ? copyAndMarkStale(account) : account;
}); });
when(mockAccountsManager.updateAsync(any(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
answer.getArgument(1, Consumer.class).accept(account);
return CompletableFuture.completedFuture(markStale ? copyAndMarkStale(account) : account);
});
when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> { when(mockAccountsManager.updateDevice(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class); final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class); final byte deviceId = answer.getArgument(1, Byte.class);
@ -77,6 +85,14 @@ public class AccountsHelper {
return markStale ? copyAndMarkStale(account) : account; return markStale ? copyAndMarkStale(account) : account;
}); });
when(mockAccountsManager.updateDeviceAsync(any(), anyByte(), any())).thenAnswer(answer -> {
final Account account = answer.getArgument(0, Account.class);
final byte deviceId = answer.getArgument(1, Byte.class);
account.getDevice(deviceId).ifPresent(answer.getArgument(2, Consumer.class));
return CompletableFuture.completedFuture(markStale ? copyAndMarkStale(account) : account);
});
when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(answer -> { when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(answer -> {
answer.getArgument(1, Device.class).setLastSeen(answer.getArgument(2, Long.class)); answer.getArgument(1, Device.class).setLastSeen(answer.getArgument(2, Long.class));
return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {}); return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {});