Fix blocking call in waitForLinkedDevice

This commit is contained in:
Ravi Khadiwala 2025-01-27 19:24:47 -06:00 committed by ravi-signal
parent aae94ffae3
commit 1446d1acf8
2 changed files with 35 additions and 33 deletions

View File

@ -343,7 +343,7 @@ public class DeviceController {
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting") @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletableFuture<Response> waitForLinkedDevice( public CompletionStage<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier") @PathParam("tokenIdentifier")
@ -363,40 +363,35 @@ public class DeviceController {
given amount of time, this endpoint will return a status of HTTP/204. given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds, """) final int timeoutSeconds,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI));
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet(); linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
try {
return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null) { return rateLimiters.getWaitForLinkedDeviceLimiter()
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.publishPercentileHistogram(true) .thenCompose(ignored -> accounts.waitForNewLinkedDevice(
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), authenticatedDevice.getAccount().getUuid(),
io.micrometer.core.instrument.Tag.of("deviceFound", authenticatedDevice.getAuthenticatedDevice(),
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) tokenIdentifier,
.register(Metrics.globalRegistry)); Duration.ofSeconds(timeoutSeconds)))
} .thenApply(maybeDeviceInfo -> maybeDeviceInfo
}); .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
} catch (final RedisException e) { .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
// `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
// if that happens e -> Response.status(Response.Status.BAD_REQUEST).build()))
linkedDeviceListenerCounter.decrementAndGet(); .whenComplete((response, throwable) -> {
throw e; linkedDeviceListenerCounter.decrementAndGet();
}
if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
} }
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {

View File

@ -955,6 +955,8 @@ class DeviceControllerTest {
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request() .request()
@ -979,6 +981,8 @@ class DeviceControllerTest {
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request() .request()
@ -997,6 +1001,8 @@ class DeviceControllerTest {
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request() .request()
@ -1042,10 +1048,11 @@ class DeviceControllerTest {
} }
@Test @Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { void waitForLinkedDeviceRateLimited() {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)