From 3a1ecb342fa51d4334181aee3a1ab56eff93de86 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Tue, 2 Apr 2024 13:01:56 -0500 Subject: [PATCH] allow striping clients in FaultTolerantHttpClient --- .../backup/Cdn3RemoteStorageManager.java | 4 +- .../Cdn3StorageManagerConfiguration.java | 10 ++- .../http/FaultTolerantHttpClient.java | 74 ++++++++++++++----- .../backup/Cdn3RemoteStorageManagerTest.java | 3 +- .../http/FaultTolerantHttpClientTest.java | 55 +++++++++++++- 5 files changed, 124 insertions(+), 22 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java index 525836855..2508a3b32 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java @@ -86,6 +86,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { .withConnectTimeout(Duration.ofSeconds(10)) .withVersion(HttpClient.Version.HTTP_2) .withTrustedServerCertificates(cdnCaCertificates.toArray(new String[0])) + .withNumClients(configuration.numHttpClients()) .build(); // Client used for calls to storage-manager @@ -98,6 +99,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { .withRetry(retryConfiguration) .withConnectTimeout(Duration.ofSeconds(10)) .withVersion(HttpClient.Version.HTTP_2) + .withNumClients(configuration.numHttpClients()) .build(); } @@ -164,7 +166,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { if (actualSourceLength != expectedSourceLength) { throw new InvalidLengthException( - "Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength); + "Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength); } final int expectedEncryptedLength = encrypter.outputSize(expectedSourceLength); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/Cdn3StorageManagerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/Cdn3StorageManagerConfiguration.java index 8b2e0145c..3462ad9c4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/Cdn3StorageManagerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/Cdn3StorageManagerConfiguration.java @@ -6,4 +6,12 @@ import javax.validation.constraints.NotNull; public record Cdn3StorageManagerConfiguration( @NotNull String baseUri, @NotNull String clientId, - @NotNull SecretString clientSecret) {} + @NotNull SecretString clientSecret, + @NotNull Integer numHttpClients) { + + public Cdn3StorageManagerConfiguration { + if (numHttpClients == null) { + numHttpClients = 2; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java index ef729c136..cac822e86 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClient.java @@ -15,12 +15,15 @@ import java.net.http.HttpResponse; import java.security.KeyStore; import java.security.cert.CertificateException; import java.time.Duration; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.IntStream; import org.glassfish.jersey.SslConfigurator; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; @@ -30,7 +33,7 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; public class FaultTolerantHttpClient { - private final HttpClient httpClient; + private final List httpClients; private final Duration defaultRequestTimeout; private final ScheduledExecutorService retryExecutor; private final Retry retry; @@ -44,11 +47,11 @@ public class FaultTolerantHttpClient { } @VisibleForTesting - FaultTolerantHttpClient(String name, HttpClient httpClient, ScheduledExecutorService retryExecutor, + FaultTolerantHttpClient(String name, List httpClients, ScheduledExecutorService retryExecutor, Duration defaultRequestTimeout, RetryConfiguration retryConfiguration, final Predicate retryOnException, CircuitBreakerConfiguration circuitBreakerConfiguration) { - this.httpClient = httpClient; + this.httpClients = httpClients; this.retryExecutor = retryExecutor; this.defaultRequestTimeout = defaultRequestTimeout; this.breaker = CircuitBreaker.of(name + "-breaker", circuitBreakerConfiguration.toCircuitBreakerConfig()); @@ -71,14 +74,19 @@ public class FaultTolerantHttpClient { } } - public CompletableFuture> sendAsync(HttpRequest request, HttpResponse.BodyHandler bodyHandler) { + private HttpClient httpClient() { + return this.httpClients.get(ThreadLocalRandom.current().nextInt(this.httpClients.size())); + } + + public CompletableFuture> sendAsync(HttpRequest request, + HttpResponse.BodyHandler bodyHandler) { if (request.timeout().isEmpty()) { request = HttpRequest.newBuilder(request, (n, v) -> true) .timeout(defaultRequestTimeout) .build(); } - Supplier>> asyncRequest = sendAsync(httpClient, request, bodyHandler); + Supplier>> asyncRequest = sendAsync(httpClient(), request, bodyHandler); if (retry != null) { return breaker.executeCompletionStage(retryableCompletionStage(asyncRequest)).toCompletableFuture(); @@ -91,7 +99,8 @@ public class FaultTolerantHttpClient { return () -> retry.executeCompletionStage(retryExecutor, supplier); } - private Supplier>> sendAsync(HttpClient client, HttpRequest request, HttpResponse.BodyHandler bodyHandler) { + private Supplier>> sendAsync(HttpClient client, HttpRequest request, + HttpResponse.BodyHandler bodyHandler) { return () -> client.sendAsync(request, bodyHandler); } @@ -101,6 +110,7 @@ public class FaultTolerantHttpClient { private HttpClient.Redirect redirect = HttpClient.Redirect.NEVER; private Duration connectTimeout = Duration.ofSeconds(10); private Duration requestTimeout = Duration.ofSeconds(60); + private int numClients = 1; private String name; private Executor executor; @@ -174,26 +184,54 @@ public class FaultTolerantHttpClient { return this; } + /** + * Specify that the HttpClient should stripe requests across multiple HTTP clients + *

+ * A {@link java.net.http.HttpClient} configured to use HTTP/2 will open a single connection per target host and + * will send concurrent requests to that host over the same connection. If the target host has set a low HTTP/2 + * MAX_CONCURRENT_STREAMS, at MAX_CONCURRENT_STREAMS concurrent requests the client will throw IOExceptions. + *

+ * To use a higher parallelism than the host sets per connection, setting a higher numClients will increase the + * number of connections we make to the backing server. Each request will be assigned to a random client. + *

+ * This builder will refuse to {@link #build()} if the HTTP version is not HTTP/2 + * + * @param numClients The number of underlying HTTP clients to use + * @return {@code this} + */ + public Builder withNumClients(final int numClients) { + this.numClients = numClients; + return this; + } + public FaultTolerantHttpClient build() { if (this.circuitBreakerConfiguration == null || this.name == null || this.executor == null) { throw new IllegalArgumentException("Must specify circuit breaker config, name, and executor"); } - final HttpClient.Builder builder = HttpClient.newBuilder() - .connectTimeout(connectTimeout) - .followRedirects(redirect) - .version(version) - .executor(executor); - - final SslConfigurator sslConfigurator = SslConfigurator.newInstance().securityProtocol(securityProtocol); - - if (this.trustStore != null) { - sslConfigurator.trustStore(trustStore); + if (numClients > 1 && version != HttpClient.Version.HTTP_2) { + throw new IllegalArgumentException("Should not use additional HTTP clients unless using HTTP/2"); } - builder.sslContext(sslConfigurator.createSSLContext()); + final List httpClients = IntStream + .range(0, numClients) + .mapToObj(i -> { + final HttpClient.Builder builder = HttpClient.newBuilder() + .connectTimeout(connectTimeout) + .followRedirects(redirect) + .version(version) + .executor(executor); - return new FaultTolerantHttpClient(name, builder.build(), retryExecutor, requestTimeout, retryConfiguration, + final SslConfigurator sslConfigurator = SslConfigurator.newInstance().securityProtocol(securityProtocol); + + if (this.trustStore != null) { + sslConfigurator.trustStore(trustStore); + } + builder.sslContext(sslConfigurator.createSSLContext()); + return builder.build(); + }).toList(); + + return new FaultTolerantHttpClient(name, httpClients, retryExecutor, requestTimeout, retryConfiguration, retryOnException, circuitBreakerConfiguration); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManagerTest.java index c2e2a2d2e..40b0bd9e3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManagerTest.java @@ -76,7 +76,8 @@ public class Cdn3RemoteStorageManagerTest { new Cdn3StorageManagerConfiguration( wireMock.url("storage-manager/"), "clientId", - new SecretString("clientSecret"))); + new SecretString("clientSecret"), + 2)); wireMock.stubFor(get(urlEqualTo("/cdn2/source/small")) .willReturn(aResponse() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClientTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClientTest.java index f1380dfff..c0a6b52ef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClientTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/http/FaultTolerantHttpClientTest.java @@ -22,21 +22,27 @@ import io.github.resilience4j.circuitbreaker.CallNotPermittedException; import java.io.IOException; import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpHeaders; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import javax.net.ssl.SSLSession; class FaultTolerantHttpClientTest { @@ -126,7 +132,7 @@ class FaultTolerantHttpClientTest { final HttpClient mockHttpClient = mock(HttpClient.class); final FaultTolerantHttpClient client = new FaultTolerantHttpClient( "test", - mockHttpClient, + List.of(mockHttpClient), retryExecutor, Duration.ofSeconds(1), new RetryConfiguration(), @@ -150,6 +156,53 @@ class FaultTolerantHttpClientTest { verify(mockHttpClient, times(3)).sendAsync(any(), any()); } + @Test + void testMultipleClients() throws IOException, InterruptedException { + final HttpClient mockHttpClient1 = mock(HttpClient.class); + final HttpClient mockHttpClient2 = mock(HttpClient.class); + final FaultTolerantHttpClient client = new FaultTolerantHttpClient( + "test", + List.of(mockHttpClient1, mockHttpClient2), + retryExecutor, + Duration.ofSeconds(1), + new RetryConfiguration(), + throwable -> throwable instanceof IOException, + new CircuitBreakerConfiguration()); + + // Just to get a dummy HttpResponse + wireMock.stubFor(get(urlEqualTo("/ping")) + .willReturn(aResponse() + .withHeader("Content-Type", "text/plain") + .withBody("Pong!"))); + + final HttpRequest request = HttpRequest.newBuilder() + .uri(URI.create("http://localhost:" + wireMock.getPort() + "/ping")) + .GET() + .build(); + final HttpResponse response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.discarding()); + + final AtomicInteger client1Calls = new AtomicInteger(0); + final AtomicInteger client2Calls = new AtomicInteger(0); + when(mockHttpClient1.sendAsync(any(), any())) + .thenAnswer(args -> { + client1Calls.incrementAndGet(); + return CompletableFuture.completedFuture(response); + }); + when(mockHttpClient2.sendAsync(any(), any())) + .thenAnswer(args -> { + client2Calls.incrementAndGet(); + return CompletableFuture.completedFuture(response); + }); + + final int numCalls = 100; + for (int i = 0; i < numCalls; i++) { + client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join(); + } + assertThat(client2Calls.get()).isGreaterThan(0); + assertThat(client1Calls.get()).isGreaterThan(0); + assertThat(client1Calls.get() + client2Calls.get()).isEqualTo(numCalls); + } + @Test void testNetworkFailureCircuitBreaker() throws InterruptedException { CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration();