diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java index c402bce24..48eefac6a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/grpc/net/NoiseHandler.java @@ -6,6 +6,7 @@ import com.southernstorm.noise.protocol.CipherState; import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.Noise; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; @@ -16,6 +17,7 @@ import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.PromiseCombiner; import io.netty.util.internal.EmptyArrays; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -163,25 +165,35 @@ private void fail(final ChannelHandlerContext context, final Throwable cause) { @Override public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise) throws Exception { - if (message instanceof ByteBuf plaintext) { + if (message instanceof ByteBuf byteBuf) { try { // TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames final CipherState cipherState = cipherStatePair.getSender(); - final int plaintextLength = plaintext.readableBytes(); - // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the - // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a - // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. - final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; - plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); + // Server message might not fit in a single noise packet, break it up into as many chunks as we need + final PromiseCombiner pc = new PromiseCombiner(context.executor()); + while (byteBuf.isReadable()) { + final ByteBuf plaintext = byteBuf.readSlice(Math.min( + // need room for a 16-byte AEAD tag + Noise.MAX_PACKET_LEN - 16, + byteBuf.readableBytes())); - // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer - cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); + final int plaintextLength = plaintext.readableBytes(); - context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise); + // We've read these bytes from a local connection; although that likely means they're backed by a heap array, the + // buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a + // mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC. + final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()]; + plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes()); + // Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer + cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength); + + pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)))); + } + pc.finish(promise); } finally { - ReferenceCountUtil.release(plaintext); + ReferenceCountUtil.release(byteBuf); } } else { if (!(message instanceof WebSocketFrame)) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java index aa8c2c6e0..febeefa33 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/net/AbstractNoiseHandlerTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import com.southernstorm.noise.protocol.CipherStatePair; +import com.southernstorm.noise.protocol.Noise; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; @@ -19,18 +20,22 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.util.ReferenceCountUtil; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; import javax.crypto.AEADBadTagException; import javax.crypto.BadPaddingException; import javax.crypto.ShortBufferException; -import io.netty.util.ReferenceCountUtil; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest { @@ -254,4 +259,29 @@ void writeUnexpectedMessageType() throws Throwable { assertTrue(embeddedChannel.outboundMessages().isEmpty()); } + @ParameterizedTest + @ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5}) + void writeHugeOutboundMessage(final int plaintextLength) throws Throwable { + final CipherStatePair clientCipherStatePair = doHandshake(); + final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength); + final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length)); + + final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer); + assertTrue(writePlaintextFuture.isSuccess()); + + final byte[] decryptedPlaintext = new byte[plaintextLength]; + int plaintextOffset = 0; + BinaryWebSocketFrame ciphertextFrame; + while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) { + assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN); + final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content()); + ciphertextFrame.release(); + plaintextOffset += clientCipherStatePair.getReceiver() + .decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length); + } + assertArrayEquals(plaintext, decryptedPlaintext); + assertEquals(0, plaintextBuffer.refCnt()); + + } + }