Break up large outbound noise messages
This commit is contained in:
parent
542422b7b8
commit
3d96d73169
|
@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.grpc.net;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherState;
|
import com.southernstorm.noise.protocol.CipherState;
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
|
@ -16,6 +17,7 @@ import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
import io.netty.util.ReferenceCountUtil;
|
||||||
|
import io.netty.util.concurrent.PromiseCombiner;
|
||||||
import io.netty.util.internal.EmptyArrays;
|
import io.netty.util.internal.EmptyArrays;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
@ -163,25 +165,35 @@ abstract class NoiseHandler extends ChannelDuplexHandler {
|
||||||
@Override
|
@Override
|
||||||
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
|
||||||
throws Exception {
|
throws Exception {
|
||||||
if (message instanceof ByteBuf plaintext) {
|
if (message instanceof ByteBuf byteBuf) {
|
||||||
try {
|
try {
|
||||||
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
|
||||||
final CipherState cipherState = cipherStatePair.getSender();
|
final CipherState cipherState = cipherStatePair.getSender();
|
||||||
final int plaintextLength = plaintext.readableBytes();
|
|
||||||
|
|
||||||
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
// Server message might not fit in a single noise packet, break it up into as many chunks as we need
|
||||||
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
final PromiseCombiner pc = new PromiseCombiner(context.executor());
|
||||||
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
while (byteBuf.isReadable()) {
|
||||||
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
final ByteBuf plaintext = byteBuf.readSlice(Math.min(
|
||||||
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
// need room for a 16-byte AEAD tag
|
||||||
|
Noise.MAX_PACKET_LEN - 16,
|
||||||
|
byteBuf.readableBytes()));
|
||||||
|
|
||||||
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
final int plaintextLength = plaintext.readableBytes();
|
||||||
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
|
||||||
|
|
||||||
context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
|
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
|
||||||
|
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
|
||||||
|
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
|
||||||
|
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
|
||||||
|
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
|
||||||
|
|
||||||
|
// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
|
||||||
|
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
|
||||||
|
|
||||||
|
pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
|
||||||
|
}
|
||||||
|
pc.finish(promise);
|
||||||
} finally {
|
} finally {
|
||||||
ReferenceCountUtil.release(plaintext);
|
ReferenceCountUtil.release(byteBuf);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!(message instanceof WebSocketFrame)) {
|
if (!(message instanceof WebSocketFrame)) {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
import com.southernstorm.noise.protocol.CipherStatePair;
|
import com.southernstorm.noise.protocol.CipherStatePair;
|
||||||
|
import com.southernstorm.noise.protocol.Noise;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
|
@ -19,18 +20,22 @@ import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
|
||||||
|
import io.netty.util.ReferenceCountUtil;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
import javax.crypto.AEADBadTagException;
|
import javax.crypto.AEADBadTagException;
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.ShortBufferException;
|
import javax.crypto.ShortBufferException;
|
||||||
import io.netty.util.ReferenceCountUtil;
|
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
|
|
||||||
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
|
|
||||||
|
@ -254,4 +259,29 @@ abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {
|
||||||
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
assertTrue(embeddedChannel.outboundMessages().isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
|
||||||
|
void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
|
||||||
|
final CipherStatePair clientCipherStatePair = doHandshake();
|
||||||
|
final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
|
||||||
|
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));
|
||||||
|
|
||||||
|
final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
|
||||||
|
assertTrue(writePlaintextFuture.isSuccess());
|
||||||
|
|
||||||
|
final byte[] decryptedPlaintext = new byte[plaintextLength];
|
||||||
|
int plaintextOffset = 0;
|
||||||
|
BinaryWebSocketFrame ciphertextFrame;
|
||||||
|
while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) {
|
||||||
|
assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN);
|
||||||
|
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
|
||||||
|
ciphertextFrame.release();
|
||||||
|
plaintextOffset += clientCipherStatePair.getReceiver()
|
||||||
|
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
|
||||||
|
}
|
||||||
|
assertArrayEquals(plaintext, decryptedPlaintext);
|
||||||
|
assertEquals(0, plaintextBuffer.refCnt());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue