Skip to content

Commit

Permalink
Break up large outbound noise messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ravi-signal committed Jul 30, 2024
1 parent 542422b commit 3d96d73
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.southernstorm.noise.protocol.CipherState;
import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
Expand All @@ -16,6 +17,7 @@
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.PromiseCombiner;
import io.netty.util.internal.EmptyArrays;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -163,25 +165,35 @@ private void fail(final ChannelHandlerContext context, final Throwable cause) {
@Override
public void write(final ChannelHandlerContext context, final Object message, final ChannelPromise promise)
throws Exception {
if (message instanceof ByteBuf plaintext) {
if (message instanceof ByteBuf byteBuf) {
try {
// TODO Buffer/consolidate Noise writes to avoid sending a bazillion tiny (or empty) frames
final CipherState cipherState = cipherStatePair.getSender();
final int plaintextLength = plaintext.readableBytes();

// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());
// Server message might not fit in a single noise packet, break it up into as many chunks as we need
final PromiseCombiner pc = new PromiseCombiner(context.executor());
while (byteBuf.isReadable()) {
final ByteBuf plaintext = byteBuf.readSlice(Math.min(
// need room for a 16-byte AEAD tag
Noise.MAX_PACKET_LEN - 16,
byteBuf.readableBytes()));

// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);
final int plaintextLength = plaintext.readableBytes();

context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer)), promise);
// We've read these bytes from a local connection; although that likely means they're backed by a heap array, the
// buffer is read-only and won't grant us access to the underlying array. Instead, we need to copy the bytes to a
// mutable array. We also want to encrypt in place, so we allocate enough extra space for the trailing MAC.
final byte[] noiseBuffer = new byte[plaintext.readableBytes() + cipherState.getMACLength()];
plaintext.readBytes(noiseBuffer, 0, plaintext.readableBytes());

// Overwrite the plaintext with the ciphertext to avoid an extra allocation for a dedicated ciphertext buffer
cipherState.encryptWithAd(null, noiseBuffer, 0, noiseBuffer, 0, plaintextLength);

pc.add(context.write(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(noiseBuffer))));
}
pc.finish(promise);
} finally {
ReferenceCountUtil.release(plaintext);
ReferenceCountUtil.release(byteBuf);
}
} else {
if (!(message instanceof WebSocketFrame)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.southernstorm.noise.protocol.CipherStatePair;
import com.southernstorm.noise.protocol.Noise;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
Expand All @@ -19,18 +20,22 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.util.ReferenceCountUtil;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import javax.annotation.Nullable;
import javax.crypto.AEADBadTagException;
import javax.crypto.BadPaddingException;
import javax.crypto.ShortBufferException;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;

abstract class AbstractNoiseHandlerTest extends AbstractLeakDetectionTest {

Expand Down Expand Up @@ -254,4 +259,29 @@ void writeUnexpectedMessageType() throws Throwable {
assertTrue(embeddedChannel.outboundMessages().isEmpty());
}

@ParameterizedTest
@ValueSource(ints = {Noise.MAX_PACKET_LEN - 16, Noise.MAX_PACKET_LEN - 15, Noise.MAX_PACKET_LEN * 5})
void writeHugeOutboundMessage(final int plaintextLength) throws Throwable {
final CipherStatePair clientCipherStatePair = doHandshake();
final byte[] plaintext = TestRandomUtil.nextBytes(plaintextLength);
final ByteBuf plaintextBuffer = Unpooled.wrappedBuffer(Arrays.copyOf(plaintext, plaintext.length));

final ChannelFuture writePlaintextFuture = embeddedChannel.pipeline().writeAndFlush(plaintextBuffer);
assertTrue(writePlaintextFuture.isSuccess());

final byte[] decryptedPlaintext = new byte[plaintextLength];
int plaintextOffset = 0;
BinaryWebSocketFrame ciphertextFrame;
while ((ciphertextFrame = (BinaryWebSocketFrame) embeddedChannel.outboundMessages().poll()) != null) {
assertTrue(ciphertextFrame.content().readableBytes() <= Noise.MAX_PACKET_LEN);
final byte[] ciphertext = ByteBufUtil.getBytes(ciphertextFrame.content());
ciphertextFrame.release();
plaintextOffset += clientCipherStatePair.getReceiver()
.decryptWithAd(null, ciphertext, 0, decryptedPlaintext, plaintextOffset, ciphertext.length);
}
assertArrayEquals(plaintext, decryptedPlaintext);
assertEquals(0, plaintextBuffer.refCnt());

}

}

0 comments on commit 3d96d73

Please sign in to comment.