Skip to content

credentials/alts: Optimize Reads (Roll forward #8236) #8271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions credentials/alts/internal/conn/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
if len(b) < MsgLenFieldSize {
length, sufficientBytes := parseMessageLength(b)
if !sufficientBytes {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
Expand All @@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}

// parseMessageLength returns the message length based on frame header. It also
// returns a boolean indicating if the buffer contains sufficient bytes to parse
// the length header. If there are insufficient bytes, (0, false) is returned.
func parseMessageLength(b []byte) (uint32, bool) {
if len(b) < MsgLenFieldSize {
return 0, false
}
msgLenField := b[:MsgLenFieldSize]
return binary.LittleEndian.Uint32(msgLenField), true
}
58 changes: 38 additions & 20 deletions credentials/alts/internal/conn/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ const (
// The maximum write buffer size. This *must* be multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
// The initial buffer used to read from the network.
altsReadBufferInitialSize = 32 * 1024 // 32KiB
)

var (
Expand All @@ -87,7 +89,7 @@ type conn struct {
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read.
// but has not yet been returned by Read. It is a sub-slice of protected.
buf []byte
payloadLengthLimit int
// protected holds data read from the network but have not yet been
Expand Down Expand Up @@ -115,21 +117,13 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
var protectedBuf []byte
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
} else {
protectedBuf = make([]byte, len(protected))
copy(protectedBuf, protected)
}
// We pre-allocate protected to be of size 32KB during initialization.
// We increase the size of the buffer by the required amount if it can't
// hold a complete encrypted record.
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
// Copy additional data from hanshaker service.
copy(protectedBuf, protected)
protectedBuf = protectedBuf[:len(protected)]

altsConn := &conn{
Conn: c,
Expand Down Expand Up @@ -166,11 +160,26 @@ func (p *conn) Read(b []byte) (n int, err error) {
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
// We can parse the length header to know exactly how large
// the buffer needs to be to hold the entire frame.
length, didParse := parseMessageLength(p.protected)
if !didParse {
// The protected buffer is initialized with a capacity of
// larger than 4B. It should always hold the message length
// header.
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
}
oldProtectedBuf := p.protected
// The new buffer must be able to hold the message length header
// and the entire message.
requiredCapacity := int(length) + MsgLenFieldSize
p.protected = make([]byte, requiredCapacity)
// Copy the contents of the old buffer and set the length of the
// new buffer to the number of bytes already read.
copy(p.protected, oldProtectedBuf)
p.protected = p.protected[:len(oldProtectedBuf)]
}
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
if err != nil {
return 0, err
}
Expand All @@ -189,6 +198,15 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
ciphertext := msg[msgTypeFieldSize:]

// Decrypt directly into the buffer, avoiding a copy from p.buf if
// possible.
if len(b) >= len(ciphertext) {
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
if err != nil {
return 0, err
}
return len(dec), nil
}
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
Expand Down
43 changes: 43 additions & 0 deletions credentials/alts/internal/conn/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"math"
"net"
"reflect"
"strings"
"testing"

core "google.golang.org/grpc/credentials/alts/internal"
Expand Down Expand Up @@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) {
}
}

// TestLargeRecord writes a very large ALTS record and verifies that the server
// receives it correctly. The large ALTS record should cause the reader to
// expand it's read buffer to hold the entire record and store the decrypted
// message until the receiver reads all of the bytes.
func (s) TestLargeRecord(t *testing.T) {
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
// Increase the size of ALTS records written by the client.
clientConn.payloadLengthLimit = math.MaxInt32
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}

// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
// receiving a large message.
func BenchmarkLargeMessage(b *testing.B) {
msgLen := 20 * 1024 * 1024 // 20 MiB
msg := make([]byte, msgLen)
rcvMsg := make([]byte, len(msg))
b.ResetTimer()
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
for range b.N {
// Write 20 MiB 5 times to transfer a total of 100 MiB.
for range 5 {
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
}
}
}

func testIncorrectMsgType(t *testing.T, rp string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.
Expand Down
2 changes: 1 addition & 1 deletion credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
var lastWriteTime time.Time
buf := make([]byte, frameLimit)
for {
if len(resp.OutFrames) > 0 {
lastWriteTime = time.Now()
Expand All @@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err
Expand Down