allow striping clients in FaultTolerantHttpClient

This commit is contained in:
Ravi Khadiwala 2024-04-02 13:01:56 -05:00 committed by ravi-signal
parent bb0da69c9e
commit 3a1ecb342f
5 changed files with 124 additions and 22 deletions

View File

@ -86,6 +86,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
.withConnectTimeout(Duration.ofSeconds(10))
.withVersion(HttpClient.Version.HTTP_2)
.withTrustedServerCertificates(cdnCaCertificates.toArray(new String[0]))
.withNumClients(configuration.numHttpClients())
.build();
// Client used for calls to storage-manager
@ -98,6 +99,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
.withRetry(retryConfiguration)
.withConnectTimeout(Duration.ofSeconds(10))
.withVersion(HttpClient.Version.HTTP_2)
.withNumClients(configuration.numHttpClients())
.build();
}
@ -164,7 +166,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
if (actualSourceLength != expectedSourceLength) {
throw new InvalidLengthException(
"Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength);
"Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength);
}
final int expectedEncryptedLength = encrypter.outputSize(expectedSourceLength);

View File

@ -6,4 +6,12 @@ import javax.validation.constraints.NotNull;
public record Cdn3StorageManagerConfiguration(
@NotNull String baseUri,
@NotNull String clientId,
@NotNull SecretString clientSecret) {}
@NotNull SecretString clientSecret,
@NotNull Integer numHttpClients) {
public Cdn3StorageManagerConfiguration {
if (numHttpClients == null) {
numHttpClients = 2;
}
}
}

View File

@ -15,12 +15,15 @@ import java.net.http.HttpResponse;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import org.glassfish.jersey.SslConfigurator;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
@ -30,7 +33,7 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
public class FaultTolerantHttpClient {
private final HttpClient httpClient;
private final List<HttpClient> httpClients;
private final Duration defaultRequestTimeout;
private final ScheduledExecutorService retryExecutor;
private final Retry retry;
@ -44,11 +47,11 @@ public class FaultTolerantHttpClient {
}
@VisibleForTesting
FaultTolerantHttpClient(String name, HttpClient httpClient, ScheduledExecutorService retryExecutor,
FaultTolerantHttpClient(String name, List<HttpClient> httpClients, ScheduledExecutorService retryExecutor,
Duration defaultRequestTimeout, RetryConfiguration retryConfiguration,
final Predicate<Throwable> retryOnException, CircuitBreakerConfiguration circuitBreakerConfiguration) {
this.httpClient = httpClient;
this.httpClients = httpClients;
this.retryExecutor = retryExecutor;
this.defaultRequestTimeout = defaultRequestTimeout;
this.breaker = CircuitBreaker.of(name + "-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig());
@ -71,14 +74,19 @@ public class FaultTolerantHttpClient {
}
}
public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> bodyHandler) {
private HttpClient httpClient() {
return this.httpClients.get(ThreadLocalRandom.current().nextInt(this.httpClients.size()));
}
public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request,
HttpResponse.BodyHandler<T> bodyHandler) {
if (request.timeout().isEmpty()) {
request = HttpRequest.newBuilder(request, (n, v) -> true)
.timeout(defaultRequestTimeout)
.build();
}
Supplier<CompletionStage<HttpResponse<T>>> asyncRequest = sendAsync(httpClient, request, bodyHandler);
Supplier<CompletionStage<HttpResponse<T>>> asyncRequest = sendAsync(httpClient(), request, bodyHandler);
if (retry != null) {
return breaker.executeCompletionStage(retryableCompletionStage(asyncRequest)).toCompletableFuture();
@ -91,7 +99,8 @@ public class FaultTolerantHttpClient {
return () -> retry.executeCompletionStage(retryExecutor, supplier);
}
private <T> Supplier<CompletionStage<HttpResponse<T>>> sendAsync(HttpClient client, HttpRequest request, HttpResponse.BodyHandler<T> bodyHandler) {
private <T> Supplier<CompletionStage<HttpResponse<T>>> sendAsync(HttpClient client, HttpRequest request,
HttpResponse.BodyHandler<T> bodyHandler) {
return () -> client.sendAsync(request, bodyHandler);
}
@ -101,6 +110,7 @@ public class FaultTolerantHttpClient {
private HttpClient.Redirect redirect = HttpClient.Redirect.NEVER;
private Duration connectTimeout = Duration.ofSeconds(10);
private Duration requestTimeout = Duration.ofSeconds(60);
private int numClients = 1;
private String name;
private Executor executor;
@ -174,26 +184,54 @@ public class FaultTolerantHttpClient {
return this;
}
/**
* Specify that the HttpClient should stripe requests across multiple HTTP clients
* <p>
* A {@link java.net.http.HttpClient} configured to use HTTP/2 will open a single connection per target host and
* will send concurrent requests to that host over the same connection. If the target host has set a low HTTP/2
* MAX_CONCURRENT_STREAMS, at MAX_CONCURRENT_STREAMS concurrent requests the client will throw IOExceptions.
* <p>
* To use a higher parallelism than the host sets per connection, setting a higher numClients will increase the
* number of connections we make to the backing server. Each request will be assigned to a random client.
* <p>
* This builder will refuse to {@link #build()} if the HTTP version is not HTTP/2
*
* @param numClients The number of underlying HTTP clients to use
* @return {@code this}
*/
public Builder withNumClients(final int numClients) {
this.numClients = numClients;
return this;
}
public FaultTolerantHttpClient build() {
if (this.circuitBreakerConfiguration == null || this.name == null || this.executor == null) {
throw new IllegalArgumentException("Must specify circuit breaker config, name, and executor");
}
final HttpClient.Builder builder = HttpClient.newBuilder()
.connectTimeout(connectTimeout)
.followRedirects(redirect)
.version(version)
.executor(executor);
final SslConfigurator sslConfigurator = SslConfigurator.newInstance().securityProtocol(securityProtocol);
if (this.trustStore != null) {
sslConfigurator.trustStore(trustStore);
if (numClients > 1 && version != HttpClient.Version.HTTP_2) {
throw new IllegalArgumentException("Should not use additional HTTP clients unless using HTTP/2");
}
builder.sslContext(sslConfigurator.createSSLContext());
final List<HttpClient> httpClients = IntStream
.range(0, numClients)
.mapToObj(i -> {
final HttpClient.Builder builder = HttpClient.newBuilder()
.connectTimeout(connectTimeout)
.followRedirects(redirect)
.version(version)
.executor(executor);
return new FaultTolerantHttpClient(name, builder.build(), retryExecutor, requestTimeout, retryConfiguration,
final SslConfigurator sslConfigurator = SslConfigurator.newInstance().securityProtocol(securityProtocol);
if (this.trustStore != null) {
sslConfigurator.trustStore(trustStore);
}
builder.sslContext(sslConfigurator.createSSLContext());
return builder.build();
}).toList();
return new FaultTolerantHttpClient(name, httpClients, retryExecutor, requestTimeout, retryConfiguration,
retryOnException, circuitBreakerConfiguration);
}
}

View File

@ -76,7 +76,8 @@ public class Cdn3RemoteStorageManagerTest {
new Cdn3StorageManagerConfiguration(
wireMock.url("storage-manager/"),
"clientId",
new SecretString("clientSecret")));
new SecretString("clientSecret"),
2));
wireMock.stubFor(get(urlEqualTo("/cdn2/source/small"))
.willReturn(aResponse()

View File

@ -22,21 +22,27 @@ import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import javax.net.ssl.SSLSession;
class FaultTolerantHttpClientTest {
@ -126,7 +132,7 @@ class FaultTolerantHttpClientTest {
final HttpClient mockHttpClient = mock(HttpClient.class);
final FaultTolerantHttpClient client = new FaultTolerantHttpClient(
"test",
mockHttpClient,
List.of(mockHttpClient),
retryExecutor,
Duration.ofSeconds(1),
new RetryConfiguration(),
@ -150,6 +156,53 @@ class FaultTolerantHttpClientTest {
verify(mockHttpClient, times(3)).sendAsync(any(), any());
}
@Test
void testMultipleClients() throws IOException, InterruptedException {
final HttpClient mockHttpClient1 = mock(HttpClient.class);
final HttpClient mockHttpClient2 = mock(HttpClient.class);
final FaultTolerantHttpClient client = new FaultTolerantHttpClient(
"test",
List.of(mockHttpClient1, mockHttpClient2),
retryExecutor,
Duration.ofSeconds(1),
new RetryConfiguration(),
throwable -> throwable instanceof IOException,
new CircuitBreakerConfiguration());
// Just to get a dummy HttpResponse
wireMock.stubFor(get(urlEqualTo("/ping"))
.willReturn(aResponse()
.withHeader("Content-Type", "text/plain")
.withBody("Pong!")));
final HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create("http://localhost:" + wireMock.getPort() + "/ping"))
.GET()
.build();
final HttpResponse response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.discarding());
final AtomicInteger client1Calls = new AtomicInteger(0);
final AtomicInteger client2Calls = new AtomicInteger(0);
when(mockHttpClient1.sendAsync(any(), any()))
.thenAnswer(args -> {
client1Calls.incrementAndGet();
return CompletableFuture.completedFuture(response);
});
when(mockHttpClient2.sendAsync(any(), any()))
.thenAnswer(args -> {
client2Calls.incrementAndGet();
return CompletableFuture.completedFuture(response);
});
final int numCalls = 100;
for (int i = 0; i < numCalls; i++) {
client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join();
}
assertThat(client2Calls.get()).isGreaterThan(0);
assertThat(client1Calls.get()).isGreaterThan(0);
assertThat(client1Calls.get() + client2Calls.get()).isEqualTo(numCalls);
}
@Test
void testNetworkFailureCircuitBreaker() throws InterruptedException {
CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration();