Make identity token fetcher more async friendly

After the identity token expires a subsequent call would do a blocking
operation to retrieve the new token. Since we're making use of an async
gRPC client, this tends to block a thread we don't want to be blocking
on.

Instead, switch to periodically refreshing the token on a dedicated
thread.
This commit is contained in:
Ravi Khadiwala 2024-01-23 16:55:57 -06:00 committed by ravi-signal
parent 498ace0488
commit 1428ca73de
4 changed files with 160 additions and 29 deletions

View File

@ -436,6 +436,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.scheduledExecutorService(name(getClass(), "hCaptchaRetry-%d")).threads(1).build();
ScheduledExecutorService remoteStorageExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "remoteStorageRetry-%d")).threads(1).build();
ScheduledExecutorService registrationIdentityTokenRefreshExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "registrationIdentityTokenRefresh-%d")).threads(1).build();
Scheduler messageDeliveryScheduler = Schedulers.fromExecutorService(
ExecutorServiceMetrics.monitor(Metrics.globalRegistry,
@ -523,7 +525,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
: config.getRegistrationServiceConfiguration().credentialConfigurationJson(),
config.getRegistrationServiceConfiguration().identityTokenAudience(),
config.getRegistrationServiceConfiguration().registrationCaCertificate(),
registrationCallbackExecutor);
registrationCallbackExecutor,
registrationIdentityTokenRefreshExecutor);
SecureValueRecovery2Client secureValueRecovery2Client = new SecureValueRecovery2Client(svr2CredentialsGenerator,
secureValueRecoveryServiceExecutor, secureValueRecoveryServiceRetryExecutor, config.getSvr2Configuration());
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator,

View File

@ -1,11 +1,19 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.registration;
import com.google.auth.oauth2.ExternalAccountCredentials;
import com.google.auth.oauth2.ImpersonatedCredentials;
import com.google.common.base.Suppliers;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.Retry;
import io.github.resilience4j.retry.RetryConfig;
import io.grpc.CallCredentials;
import io.grpc.Metadata;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
@ -13,56 +21,100 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class IdentityTokenCallCredentials extends CallCredentials {
private final Supplier<String> identityTokenSupplier;
class IdentityTokenCallCredentials extends CallCredentials implements Closeable {
private static final Duration IDENTITY_TOKEN_LIFETIME = Duration.ofHours(1);
private static final Duration IDENTITY_TOKEN_REFRESH_BUFFER = Duration.ofMinutes(10);
private static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);
private static final Logger logger = LoggerFactory.getLogger(IdentityTokenCallCredentials.class);
IdentityTokenCallCredentials(final Supplier<String> identityTokenSupplier) {
this.identityTokenSupplier = identityTokenSupplier;
private final Retry retry;
private final ImpersonatedCredentials impersonatedCredentials;
private final String audience;
private final ScheduledFuture<?> scheduledFuture;
private volatile Pair<String, RuntimeException> currentIdentityToken;
IdentityTokenCallCredentials(
final RetryConfig retryConfig,
final ImpersonatedCredentials impersonatedCredentials,
final String audience,
final ScheduledExecutorService scheduledExecutorService) {
this.impersonatedCredentials = impersonatedCredentials;
this.audience = audience;
this.retry = Retry.of("identity-token-fetch", retryConfig);
scheduledFuture = scheduledExecutorService.scheduleAtFixedRate(this::refreshIdentityToken,
IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(),
IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(),
TimeUnit.MILLISECONDS);
}
static IdentityTokenCallCredentials fromCredentialConfig(final String credentialConfigJson, final String audience) throws IOException {
try (final InputStream configInputStream = new ByteArrayInputStream(credentialConfigJson.getBytes(StandardCharsets.UTF_8))) {
static IdentityTokenCallCredentials fromCredentialConfig(
final String credentialConfigJson,
final String audience,
final ScheduledExecutorService scheduledExecutorService) throws IOException {
try (final InputStream configInputStream = new ByteArrayInputStream(
credentialConfigJson.getBytes(StandardCharsets.UTF_8))) {
final ExternalAccountCredentials credentials = ExternalAccountCredentials.fromStream(configInputStream);
final ImpersonatedCredentials impersonatedCredentials = ImpersonatedCredentials.create(credentials,
credentials.getServiceAccountEmail(), null, List.of(), (int) IDENTITY_TOKEN_LIFETIME.toSeconds());
final Supplier<String> idTokenSupplier = Suppliers.memoizeWithExpiration(() -> {
try {
impersonatedCredentials.getSourceCredentials().refresh();
return impersonatedCredentials.idTokenWithAudience(audience, null).getTokenValue();
} catch (final IOException e) {
logger.warn("Failed to retrieve identity token", e);
throw new UncheckedIOException(e);
}
},
IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(),
TimeUnit.MILLISECONDS);
final IdentityTokenCallCredentials identityTokenCallCredentials = new IdentityTokenCallCredentials(
RetryConfig.custom()
.retryOnException(throwable -> true)
.maxAttempts(Integer.MAX_VALUE)
.intervalFunction(IntervalFunction.ofExponentialRandomBackoff(
Duration.ofMillis(100), 1.5, Duration.ofSeconds(5)))
.build(), impersonatedCredentials, audience, scheduledExecutorService);
return new IdentityTokenCallCredentials(idTokenSupplier);
// Make sure credentials are initially populated
identityTokenCallCredentials.refreshIdentityToken();
return identityTokenCallCredentials;
}
}
@VisibleForTesting
void refreshIdentityToken() {
retry.executeRunnable(() -> {
try {
impersonatedCredentials.getSourceCredentials().refresh();
this.currentIdentityToken = Pair.of(
impersonatedCredentials.idTokenWithAudience(audience, null).getTokenValue(),
null);
} catch (final IOException e) {
logger.warn("Failed to retrieve identity token", e);
final UncheckedIOException wrapped = new UncheckedIOException(e);
this.currentIdentityToken = Pair.of(null, wrapped);
throw wrapped;
} catch (final RuntimeException e) {
logger.error("Failed to retrieve identity token", e);
this.currentIdentityToken = Pair.of(null, e);
throw e;
}
});
}
@Override
public void applyRequestMetadata(final RequestInfo requestInfo,
final Executor appExecutor,
final MetadataApplier applier) {
@Nullable final String identityTokenValue = identityTokenSupplier.get();
final Pair<String, RuntimeException> pair = currentIdentityToken;
if (pair.getRight() != null) {
throw pair.getRight();
}
final String identityTokenValue = pair.getLeft();
if (identityTokenValue != null) {
final Metadata metadata = new Metadata();
@ -75,4 +127,13 @@ class IdentityTokenCallCredentials extends CallCredentials {
@Override
public void thisUsesUnstableApi() {
}
@Override
public void close() {
synchronized (this) {
if (!scheduledFuture.isDone()) {
scheduledFuture.cancel(true);
}
}
}
}

View File

@ -21,6 +21,7 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
@ -37,6 +38,7 @@ import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
public class RegistrationServiceClient implements Managed {
private final ManagedChannel channel;
private final IdentityTokenCallCredentials identityTokenCallCredentials;
private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub;
private final Executor callbackExecutor;
@ -62,7 +64,8 @@ public class RegistrationServiceClient implements Managed {
final String credentialConfigJson,
final String identityTokenAudience,
final String caCertificatePem,
final Executor callbackExecutor) throws IOException {
final Executor callbackExecutor,
final ScheduledExecutorService identityRefreshExecutor) throws IOException {
try (final ByteArrayInputStream certificateInputStream = new ByteArrayInputStream(caCertificatePem.getBytes(StandardCharsets.UTF_8))) {
final ChannelCredentials tlsChannelCredentials = TlsChannelCredentials.newBuilder()
@ -74,8 +77,10 @@ public class RegistrationServiceClient implements Managed {
.build();
}
this.stub = RegistrationServiceGrpc.newFutureStub(channel)
.withCallCredentials(IdentityTokenCallCredentials.fromCredentialConfig(credentialConfigJson, identityTokenAudience));
this.identityTokenCallCredentials = IdentityTokenCallCredentials.fromCredentialConfig(
credentialConfigJson, identityTokenAudience, identityRefreshExecutor);
this.stub = RegistrationServiceGrpc.newFutureStub(channel).withCallCredentials(identityTokenCallCredentials);
this.callbackExecutor = callbackExecutor;
}
@ -279,5 +284,6 @@ public class RegistrationServiceClient implements Managed {
if (channel != null) {
channel.shutdown();
}
this.identityTokenCallCredentials.close();
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.registration;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.IdToken;
import com.google.auth.oauth2.ImpersonatedCredentials;
import io.github.resilience4j.core.IntervalFunction;
import io.github.resilience4j.retry.RetryConfig;
import io.grpc.CallCredentials;
import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.Executors;
import org.junit.jupiter.api.Test;
public class IdentityTokenCallCredentialsTest {
@Test
public void retryErrors() throws IOException {
final ImpersonatedCredentials impersonatedCredentials = mock(ImpersonatedCredentials.class);
when(impersonatedCredentials.getSourceCredentials()).thenReturn(mock(GoogleCredentials.class));
final IdentityTokenCallCredentials creds = new IdentityTokenCallCredentials(
RetryConfig.custom()
.retryOnException(throwable -> true)
.maxAttempts(Integer.MAX_VALUE)
.intervalFunction(IntervalFunction.ofExponentialRandomBackoff(
Duration.ofMillis(100), 1.5, Duration.ofSeconds(5)))
.build(),
impersonatedCredentials,
"test",
Executors.newSingleThreadScheduledExecutor());
final IdToken idToken = mock(IdToken.class);
when(idToken.getTokenValue()).thenReturn("testtoken");
// throw exception first two calls, then succeed
when(impersonatedCredentials.idTokenWithAudience(anyString(), any()))
.thenThrow(new IOException("uh oh 1"))
.thenThrow(new IOException("uh oh 2"))
.thenReturn(idToken)
.thenThrow(new IOException("uh oh 3"));
creds.refreshIdentityToken();
CallCredentials.MetadataApplier metadataApplier = mock(CallCredentials.MetadataApplier.class);
creds.applyRequestMetadata(null, null, metadataApplier);
verify(metadataApplier, times(1))
.apply(argThat(metadata -> "Bearer testtoken".equals(metadata.get(IdentityTokenCallCredentials.AUTHORIZATION_METADATA_KEY))));
}
}