diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java index e0d007c78..12ad9243b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreaker.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import java.net.SocketAddress; +import java.util.Collection; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -195,9 +196,6 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer { logger.trace("Breaker state is {}", breaker.getState()); - // Note: io.lettuce.core.protocol.CommandHandler also supports batches (List/Collection), - // but we do not use that feature, so we can just check for single commands - // // There are two types of RedisCommands that are not CompleteableCommand: // - io.lettuce.core.protocol.Command // - io.lettuce.core.protocol.PristineFallbackCommand @@ -206,36 +204,55 @@ public class LettuceShardCircuitBreaker implements NettyCustomizer { // to consume responses. if (msg instanceof RedisCommand rc && rc instanceof CompleteableCommand command) { try { - breaker.acquirePermission(); - - // state can change in acquirePermission() - logger.trace("Breaker is permitted: {}", breaker.getState()); - - final long startNanos = System.nanoTime(); - - command.onComplete((ignored, throwable) -> { - final long durationNanos = System.nanoTime() - startNanos; - - // RedisNoScriptException doesn’t indicate a fault the breaker can protect - if (throwable != null && !(throwable instanceof RedisNoScriptException)) { - breaker.onError(durationNanos, TimeUnit.NANOSECONDS, throwable); - logger.warn("Command completed with error", throwable); - } else { - breaker.onSuccess(durationNanos, TimeUnit.NANOSECONDS); - } - }); - + instrumentCommand(command); } catch (final CallNotPermittedException e) { rc.completeExceptionally(e); promise.tryFailure(e); return; } + } else if (msg instanceof Collection collection && + !collection.isEmpty() && + collection.stream().allMatch(obj -> obj instanceof RedisCommand && obj instanceof CompleteableCommand)) { + + @SuppressWarnings("unchecked") final Collection> commandCollection = + (Collection>) collection; + + try { + // If we have a collection of commands, we only acquire a single permit for the whole batch (since there's + // only a single write promise to fail). We choose a single command from the collection to sample for failure. + instrumentCommand((CompleteableCommand) commandCollection.iterator().next()); + } catch (final CallNotPermittedException e) { + commandCollection.forEach(redisCommand -> redisCommand.completeExceptionally(e)); + promise.tryFailure(e); + return; + } } else { logger.warn("Unexpected msg type: {}", msg.getClass()); } super.write(ctx, msg, promise); } + + private void instrumentCommand(final CompleteableCommand command) throws CallNotPermittedException { + breaker.acquirePermission(); + + // state can change in acquirePermission() + logger.trace("Breaker is permitted: {}", breaker.getState()); + + final long startNanos = System.nanoTime(); + + command.onComplete((ignored, throwable) -> { + final long durationNanos = System.nanoTime() - startNanos; + + // RedisNoScriptException doesn’t indicate a fault the breaker can protect + if (throwable != null && !(throwable instanceof RedisNoScriptException)) { + breaker.onError(durationNanos, TimeUnit.NANOSECONDS, throwable); + logger.warn("Command completed with error", throwable); + } else { + breaker.onSuccess(durationNanos, TimeUnit.NANOSECONDS); + } + }); + } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java index 24c6e9b82..155294eb9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/LettuceShardCircuitBreakerTest.java @@ -44,6 +44,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -103,8 +104,8 @@ class LettuceShardCircuitBreakerTest { } @ParameterizedTest - @MethodSource - void testHandlerWriteBreakerClosed(@Nullable final Throwable t) throws Exception { + @ValueSource(booleans = {true, false}) + void testHandlerWriteBreakerClosed(final boolean completeExceptionally) throws Exception { final CircuitBreaker breaker = mock(CircuitBreaker.class); channelCircuitBreakerHandler.breaker = breaker; @@ -116,9 +117,11 @@ class LettuceShardCircuitBreakerTest { verify(breaker).acquirePermission(); - if (t != null) { - command.completeExceptionally(t); - verify(breaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), eq(t)); + if (completeExceptionally) { + final Throwable throwable = new IOException("timeout"); + + command.completeExceptionally(throwable); + verify(breaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), eq(throwable)); } else { command.complete("PONG"); verify(breaker).onSuccess(anyLong(), eq(TimeUnit.NANOSECONDS)); @@ -128,12 +131,36 @@ class LettuceShardCircuitBreakerTest { verify(channelHandlerContext).write(command, channelPromise); } - static List testHandlerWriteBreakerClosed() { - final List errors = new ArrayList<>(); - errors.add(null); - errors.add(new IOException("timeout")); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testHandlerWriteBatchBreakerClosed(final boolean completeExceptionally) throws Exception { + final CircuitBreaker breaker = mock(CircuitBreaker.class); + channelCircuitBreakerHandler.breaker = breaker; - return errors; + final AsyncCommand firstCommand = new AsyncCommand<>( + new Command<>(CommandType.PING, new StatusOutput<>(StringCodec.ASCII))); + final AsyncCommand secondCommand = new AsyncCommand<>( + new Command<>(CommandType.PING, new StatusOutput<>(StringCodec.ASCII))); + final ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class); + final ChannelPromise channelPromise = mock(ChannelPromise.class); + channelCircuitBreakerHandler.write(channelHandlerContext, List.of(firstCommand, secondCommand), channelPromise); + + verify(breaker).acquirePermission(); + + if (completeExceptionally) { + final Throwable throwable = new IOException("timeout"); + + firstCommand.completeExceptionally(throwable); + secondCommand.completeExceptionally(throwable); + verify(breaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), eq(throwable)); + } else { + firstCommand.complete("PONG"); + secondCommand.complete("PONG"); + verify(breaker).onSuccess(anyLong(), eq(TimeUnit.NANOSECONDS)); + } + + // write should always be forwarded when the breaker is closed + verify(channelHandlerContext).write(List.of(firstCommand, secondCommand), channelPromise); } @Test @@ -155,4 +182,24 @@ class LettuceShardCircuitBreakerTest { verifyNoInteractions(channelHandlerContext); } + @Test + void testHandlerWriteBatchBreakerOpen() throws Exception { + final CircuitBreaker breaker = mock(CircuitBreaker.class); + channelCircuitBreakerHandler.breaker = breaker; + + final CallNotPermittedException callNotPermittedException = mock(CallNotPermittedException.class); + doThrow(callNotPermittedException).when(breaker).acquirePermission(); + + @SuppressWarnings("unchecked") final AsyncCommand firstCommand = mock(AsyncCommand.class); + @SuppressWarnings("unchecked") final AsyncCommand secondCommand = mock(AsyncCommand.class); + final ChannelHandlerContext channelHandlerContext = mock(ChannelHandlerContext.class); + final ChannelPromise channelPromise = mock(ChannelPromise.class); + channelCircuitBreakerHandler.write(channelHandlerContext, List.of(firstCommand, secondCommand), channelPromise); + + verify(firstCommand).completeExceptionally(callNotPermittedException); + verify(secondCommand).completeExceptionally(callNotPermittedException); + verify(channelPromise).tryFailure(callNotPermittedException); + + verifyNoInteractions(channelHandlerContext); + } }