diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index cf8fc727f..a4eb035b8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -436,6 +436,8 @@ public class WhisperServerService extends Application identityTokenSupplier; - +class IdentityTokenCallCredentials extends CallCredentials implements Closeable { private static final Duration IDENTITY_TOKEN_LIFETIME = Duration.ofHours(1); private static final Duration IDENTITY_TOKEN_REFRESH_BUFFER = Duration.ofMinutes(10); - private static final Metadata.Key AUTHORIZATION_METADATA_KEY = + static final Metadata.Key AUTHORIZATION_METADATA_KEY = Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER); private static final Logger logger = LoggerFactory.getLogger(IdentityTokenCallCredentials.class); - IdentityTokenCallCredentials(final Supplier identityTokenSupplier) { - this.identityTokenSupplier = identityTokenSupplier; + private final Retry retry; + private final ImpersonatedCredentials impersonatedCredentials; + private final String audience; + private final ScheduledFuture scheduledFuture; + + private volatile Pair currentIdentityToken; + + IdentityTokenCallCredentials( + final RetryConfig retryConfig, + final ImpersonatedCredentials impersonatedCredentials, + final String audience, + final ScheduledExecutorService scheduledExecutorService) { + this.impersonatedCredentials = impersonatedCredentials; + this.audience = audience; + this.retry = Retry.of("identity-token-fetch", retryConfig); + scheduledFuture = scheduledExecutorService.scheduleAtFixedRate(this::refreshIdentityToken, + IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(), + IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(), + TimeUnit.MILLISECONDS); } - static IdentityTokenCallCredentials fromCredentialConfig(final String credentialConfigJson, final String audience) throws IOException { - try (final InputStream configInputStream = new ByteArrayInputStream(credentialConfigJson.getBytes(StandardCharsets.UTF_8))) { + static IdentityTokenCallCredentials fromCredentialConfig( + final String credentialConfigJson, + final String audience, + final ScheduledExecutorService scheduledExecutorService) throws IOException { + try (final InputStream configInputStream = new ByteArrayInputStream( + credentialConfigJson.getBytes(StandardCharsets.UTF_8))) { final ExternalAccountCredentials credentials = ExternalAccountCredentials.fromStream(configInputStream); final ImpersonatedCredentials impersonatedCredentials = ImpersonatedCredentials.create(credentials, credentials.getServiceAccountEmail(), null, List.of(), (int) IDENTITY_TOKEN_LIFETIME.toSeconds()); - final Supplier idTokenSupplier = Suppliers.memoizeWithExpiration(() -> { - try { - impersonatedCredentials.getSourceCredentials().refresh(); - return impersonatedCredentials.idTokenWithAudience(audience, null).getTokenValue(); - } catch (final IOException e) { - logger.warn("Failed to retrieve identity token", e); - throw new UncheckedIOException(e); - } - }, - IDENTITY_TOKEN_LIFETIME.minus(IDENTITY_TOKEN_REFRESH_BUFFER).toMillis(), - TimeUnit.MILLISECONDS); + final IdentityTokenCallCredentials identityTokenCallCredentials = new IdentityTokenCallCredentials( + RetryConfig.custom() + .retryOnException(throwable -> true) + .maxAttempts(Integer.MAX_VALUE) + .intervalFunction(IntervalFunction.ofExponentialRandomBackoff( + Duration.ofMillis(100), 1.5, Duration.ofSeconds(5))) + .build(), impersonatedCredentials, audience, scheduledExecutorService); - return new IdentityTokenCallCredentials(idTokenSupplier); + // Make sure credentials are initially populated + identityTokenCallCredentials.refreshIdentityToken(); + + return identityTokenCallCredentials; } } + @VisibleForTesting + void refreshIdentityToken() { + retry.executeRunnable(() -> { + try { + impersonatedCredentials.getSourceCredentials().refresh(); + this.currentIdentityToken = Pair.of( + impersonatedCredentials.idTokenWithAudience(audience, null).getTokenValue(), + null); + } catch (final IOException e) { + logger.warn("Failed to retrieve identity token", e); + final UncheckedIOException wrapped = new UncheckedIOException(e); + this.currentIdentityToken = Pair.of(null, wrapped); + throw wrapped; + } catch (final RuntimeException e) { + logger.error("Failed to retrieve identity token", e); + this.currentIdentityToken = Pair.of(null, e); + throw e; + } + }); + } + @Override public void applyRequestMetadata(final RequestInfo requestInfo, final Executor appExecutor, final MetadataApplier applier) { - @Nullable final String identityTokenValue = identityTokenSupplier.get(); + final Pair pair = currentIdentityToken; + if (pair.getRight() != null) { + throw pair.getRight(); + } + + final String identityTokenValue = pair.getLeft(); if (identityTokenValue != null) { final Metadata metadata = new Metadata(); @@ -75,4 +127,13 @@ class IdentityTokenCallCredentials extends CallCredentials { @Override public void thisUsesUnstableApi() { } + + @Override + public void close() { + synchronized (this) { + if (!scheduledFuture.isDone()) { + scheduledFuture.cancel(true); + } + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java index ed8eecfb5..d52d25d8b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java @@ -21,6 +21,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.StringUtils; import org.checkerframework.checker.nullness.qual.Nullable; @@ -37,6 +38,7 @@ import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; public class RegistrationServiceClient implements Managed { private final ManagedChannel channel; + private final IdentityTokenCallCredentials identityTokenCallCredentials; private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub; private final Executor callbackExecutor; @@ -62,7 +64,8 @@ public class RegistrationServiceClient implements Managed { final String credentialConfigJson, final String identityTokenAudience, final String caCertificatePem, - final Executor callbackExecutor) throws IOException { + final Executor callbackExecutor, + final ScheduledExecutorService identityRefreshExecutor) throws IOException { try (final ByteArrayInputStream certificateInputStream = new ByteArrayInputStream(caCertificatePem.getBytes(StandardCharsets.UTF_8))) { final ChannelCredentials tlsChannelCredentials = TlsChannelCredentials.newBuilder() @@ -74,8 +77,10 @@ public class RegistrationServiceClient implements Managed { .build(); } - this.stub = RegistrationServiceGrpc.newFutureStub(channel) - .withCallCredentials(IdentityTokenCallCredentials.fromCredentialConfig(credentialConfigJson, identityTokenAudience)); + this.identityTokenCallCredentials = IdentityTokenCallCredentials.fromCredentialConfig( + credentialConfigJson, identityTokenAudience, identityRefreshExecutor); + + this.stub = RegistrationServiceGrpc.newFutureStub(channel).withCallCredentials(identityTokenCallCredentials); this.callbackExecutor = callbackExecutor; } @@ -279,5 +284,6 @@ public class RegistrationServiceClient implements Managed { if (channel != null) { channel.shutdown(); } + this.identityTokenCallCredentials.close(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/registration/IdentityTokenCallCredentialsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/registration/IdentityTokenCallCredentialsTest.java new file mode 100644 index 000000000..f012c1f17 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/registration/IdentityTokenCallCredentialsTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.registration; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.IdToken; +import com.google.auth.oauth2.ImpersonatedCredentials; +import io.github.resilience4j.core.IntervalFunction; +import io.github.resilience4j.retry.RetryConfig; +import io.grpc.CallCredentials; +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.Executors; +import org.junit.jupiter.api.Test; + +public class IdentityTokenCallCredentialsTest { + + @Test + public void retryErrors() throws IOException { + final ImpersonatedCredentials impersonatedCredentials = mock(ImpersonatedCredentials.class); + when(impersonatedCredentials.getSourceCredentials()).thenReturn(mock(GoogleCredentials.class)); + + final IdentityTokenCallCredentials creds = new IdentityTokenCallCredentials( + RetryConfig.custom() + .retryOnException(throwable -> true) + .maxAttempts(Integer.MAX_VALUE) + .intervalFunction(IntervalFunction.ofExponentialRandomBackoff( + Duration.ofMillis(100), 1.5, Duration.ofSeconds(5))) + .build(), + impersonatedCredentials, + "test", + Executors.newSingleThreadScheduledExecutor()); + + final IdToken idToken = mock(IdToken.class); + when(idToken.getTokenValue()).thenReturn("testtoken"); + + // throw exception first two calls, then succeed + when(impersonatedCredentials.idTokenWithAudience(anyString(), any())) + .thenThrow(new IOException("uh oh 1")) + .thenThrow(new IOException("uh oh 2")) + .thenReturn(idToken) + .thenThrow(new IOException("uh oh 3")); + + creds.refreshIdentityToken(); + CallCredentials.MetadataApplier metadataApplier = mock(CallCredentials.MetadataApplier.class); + creds.applyRequestMetadata(null, null, metadataApplier); + verify(metadataApplier, times(1)) + .apply(argThat(metadata -> "Bearer testtoken".equals(metadata.get(IdentityTokenCallCredentials.AUTHORIZATION_METADATA_KEY)))); + } +}