From 370ca786b96aa33c6c93800d8f8fc55ecc4c11bb Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Mon, 26 Aug 2024 17:03:42 +0300 Subject: [PATCH 1/2] client,server: configurable wire message size limits. Implement configurable limits for the maximum accepted message size of the wire protocol. The default limit can be overridden using the WithClientWireMessageLimit() option for clients and using the WithServerWireMessageLimit() option for servers. Add exported constants for the minimum, maximum and default limits. Signed-off-by: Krisztian Litkey --- channel.go | 61 ++++++++++++++++++++++++++++++++++++++---------------- client.go | 19 ++++++++++++----- config.go | 10 +++++++++ errors.go | 50 +++++++++++++++++++++++++++++++++++++++----- server.go | 21 ++++++++++++++++++- 5 files changed, 132 insertions(+), 29 deletions(-) diff --git a/channel.go b/channel.go index 872261e6d..50c917adc 100644 --- a/channel.go +++ b/channel.go @@ -23,14 +23,13 @@ import ( "io" "net" "sync" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( - messageHeaderLength = 10 - messageLengthMax = 4 << 20 + messageHeaderLength = 10 + MinMessageLengthLimit = 4 << 10 + MaxMessageLengthLimit = 4 << 22 + DefaultMessageLengthLimit = 4 << 20 ) type messageType uint8 @@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error { var buffers sync.Pool type channel struct { - conn net.Conn - bw *bufio.Writer - br *bufio.Reader - hrbuf [messageHeaderLength]byte // avoid alloc when reading header - hwbuf [messageHeaderLength]byte + conn net.Conn + bw *bufio.Writer + br *bufio.Reader + hrbuf [messageHeaderLength]byte // avoid alloc when reading header + hwbuf [messageHeaderLength]byte + maxMsgLen int } -func newChannel(conn net.Conn) *channel { +func newChannel(conn net.Conn, maxMsgLen int) *channel { + if maxMsgLen == 0 { + maxMsgLen = DefaultMessageLengthLimit + } return &channel{ - conn: conn, - bw: bufio.NewWriter(conn), - br: bufio.NewReader(conn), + conn: conn, + bw: bufio.NewWriter(conn), + br: bufio.NewReader(conn), + maxMsgLen: maxMsgLen, } } @@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) { return messageHeader{}, nil, err } - if mh.Length > uint32(messageLengthMax) { + if maxMsgLen := ch.maxMsgLimit(true); mh.Length > uint32(maxMsgLen) { if _, err := ch.br.Discard(int(mh.Length)); err != nil { return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err) } - return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) + return mh, nil, OversizedMessageError(int(mh.Length), maxMsgLen) } var p []byte @@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - if len(p) > messageLengthMax { - return OversizedMessageError(len(p)) + if maxMsgLen := ch.maxMsgLimit(false); maxMsgLen != 0 { + if len(p) > maxMsgLen { + return OversizedMessageError(len(p), maxMsgLen) + } } if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { @@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte { func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) } + +func (ch *channel) maxMsgLimit(recv bool) int { + if ch.maxMsgLen == 0 && recv { + return DefaultMessageLengthLimit + } + return ch.maxMsgLen +} + +func clampWireMessageLimit(maxMsgLen int) int { + switch { + case maxMsgLen == 0: + return 0 + case maxMsgLen < MinMessageLengthLimit: + return MinMessageLengthLimit + case maxMsgLen > MaxMessageLengthLimit: + return MaxMessageLengthLimit + } + return maxMsgLen +} diff --git a/client.go b/client.go index b1bc7a3fc..09cfb2897 100644 --- a/client.go +++ b/client.go @@ -35,9 +35,10 @@ import ( // Client for a ttrpc server type Client struct { - codec codec - conn net.Conn - channel *channel + codec codec + conn net.Conn + channel *channel + maxMsgLen int streamLock sync.RWMutex streams map[streamID]*stream @@ -107,14 +108,20 @@ func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker } } +// WithClientWireMessageLimit sets the maximum allowed message length on the wire for the client. +func WithClientWireMessageLimit(maxMsgLen int) ClientOpts { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *Client) { + c.maxMsgLen = maxMsgLen + } +} + // NewClient creates a new ttrpc client using the given connection func NewClient(conn net.Conn, opts ...ClientOpts) *Client { ctx, cancel := context.WithCancel(context.Background()) - channel := newChannel(conn) c := &Client{ codec: codec{}, conn: conn, - channel: channel, streams: make(map[streamID]*stream), nextStreamID: 1, closed: cancel, @@ -127,6 +134,8 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client { o(c) } + c.channel = newChannel(conn, c.maxMsgLen) + if c.interceptor == nil { c.interceptor = defaultClientInterceptor } diff --git a/config.go b/config.go index f401f67be..5995b9a8b 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,7 @@ import ( type serverConfig struct { handshaker Handshaker interceptor UnaryServerInterceptor + maxMsgLen int } // ServerOpt for configuring a ttrpc server @@ -84,3 +85,12 @@ func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, intercep chainUnaryServerInterceptors(info, method, interceptors[1:])) } } + +// WithServerWireMessageLimit sets the maximum allowed message length on the wire for the server. +func WithServerWireMessageLimit(maxMsgLen int) ServerOpt { + maxMsgLen = clampWireMessageLimit(maxMsgLen) + return func(c *serverConfig) error { + c.maxMsgLen = maxMsgLen + return nil + } +} diff --git a/errors.go b/errors.go index 632dbe8bd..5f81bc491 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,7 @@ package ttrpc import ( "errors" + "fmt" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -43,20 +44,59 @@ var ( // length. type OversizedMessageErr struct { messageLength int + maxLength int err error } +var ( + oversizedMsgFmt = "message length %d exceeds maximum message size of %d" + oversizedMsgScanFmt = fmt.Sprintf("%v", status.New(codes.ResourceExhausted, oversizedMsgFmt)) +) + // OversizedMessageError returns an OversizedMessageErr error for the given message // length if it exceeds the allowed maximum. Otherwise a nil error is returned. -func OversizedMessageError(messageLength int) error { - if messageLength <= messageLengthMax { +func OversizedMessageError(messageLength, maxLength int) error { + if messageLength <= maxLength { return nil } return &OversizedMessageErr{ messageLength: messageLength, - err: status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", messageLength, messageLengthMax), + maxLength: maxLength, + err: OversizedMessageStatus(messageLength, maxLength).Err(), + } +} + +// OversizedMessageStatus returns a Status for an oversized message error. +func OversizedMessageStatus(messageLength, maxLength int) *status.Status { + return status.Newf(codes.ResourceExhausted, oversizedMsgFmt, messageLength, maxLength) +} + +// OversizedMessageFromError reconstructs an OversizedMessageErr from a Status. +func OversizedMessageFromError(err error) (*OversizedMessageErr, bool) { + var ( + messageLength int + maxLength int + ) + + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false } + + // TODO(klihub): might be too ugly to recover an error this way... An + // alternative would be to define our custom status detail proto type, + // then use status.WithDetails() and status.Details(). + + n, _ := fmt.Sscanf(st.Message(), oversizedMsgScanFmt, &messageLength, &maxLength) + if n != 2 { + n, _ = fmt.Sscanf(st.Message(), oversizedMsgFmt, &messageLength, &maxLength) + } + if n != 2 { + return nil, false + } + + return OversizedMessageError(messageLength, maxLength).(*OversizedMessageErr), true } // Error returns the error message for the corresponding grpc Status for the error. @@ -75,6 +115,6 @@ func (e *OversizedMessageErr) RejectedLength() int { } // MaximumLength retrieves the maximum allowed message length that triggered the error. -func (*OversizedMessageErr) MaximumLength() int { - return messageLengthMax +func (e *OversizedMessageErr) MaximumLength() int { + return e.maxLength } diff --git a/server.go b/server.go index bb71de677..f30d5c924 100644 --- a/server.go +++ b/server.go @@ -339,7 +339,7 @@ func (c *serverConn) run(sctx context.Context) { ) var ( - ch = newChannel(c.conn) + ch = newChannel(c.conn, c.server.config.maxMsgLen) ctx, cancel = context.WithCancel(sctx) state connState = connStateIdle responses = make(chan response) @@ -373,6 +373,14 @@ func (c *serverConn) run(sctx context.Context) { } } + isResourceExhaustedError := func(err error) (*status.Status, bool) { + st, ok := status.FromError(err) + if !ok || st.Code() != codes.ResourceExhausted { + return nil, false + } + return st, true + } + go func(recvErr chan error) { defer close(recvErr) for { @@ -525,6 +533,17 @@ func (c *serverConn) run(sctx context.Context) { } if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil { + if st, ok := isResourceExhaustedError(err); ok { + p, err = c.server.codec.Marshal(&Response{ + Status: st.Proto(), + }) + if err != nil { + log.G(ctx).WithError(err).Error("failed marshaling error response") + return + } + ch.send(response.id, messageTypeResponse, 0, p) + return + } log.G(ctx).WithError(err).Error("failed sending message on channel") return } From 7a70e2e719e4856d87fb117a92e2d766632caf18 Mon Sep 17 00:00:00 2001 From: Krisztian Litkey Date: Mon, 26 Aug 2024 17:20:35 +0300 Subject: [PATCH 2/2] {channel,server}_test: add tests for message limits. Adjust unit test to accomodate for altered internal interfaces. Add unit tests to exercise the new message size limit options. Signed-off-by: Krisztian Litkey --- channel_test.go | 6 +- server_test.go | 273 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 247 insertions(+), 32 deletions(-) diff --git a/channel_test.go b/channel_test.go index 9eab63148..30263730d 100644 --- a/channel_test.go +++ b/channel_test.go @@ -31,8 +31,8 @@ import ( func TestReadWriteMessage(t *testing.T) { var ( w, r = net.Pipe() - ch = newChannel(w) - rch = newChannel(r) + ch = newChannel(w, 0) + rch = newChannel(r, 0) messages = [][]byte{ []byte("hello"), []byte("this is a test"), @@ -90,7 +90,7 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( w, _ = net.Pipe() - wch = newChannel(w) + wch = newChannel(w, 0) msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) errs = make(chan error, 1) ) diff --git a/server_test.go b/server_test.go index 4a1561df8..0e3186772 100644 --- a/server_test.go +++ b/server_test.go @@ -19,6 +19,7 @@ package ttrpc import ( "bytes" "context" + "crypto/md5" "errors" "fmt" "net" @@ -61,10 +62,17 @@ func (tc *testingClient) Test(ctx context.Context, req *internal.TestPayload) (* } // testingServer is what would be implemented by the user of this package. -type testingServer struct{} +type testingServer struct { + echoOnce bool +} func (s *testingServer) Test(ctx context.Context, req *internal.TestPayload) (*internal.TestPayload, error) { - tp := &internal.TestPayload{Foo: strings.Repeat(req.Foo, 2)} + tp := &internal.TestPayload{} + if s.echoOnce { + tp.Foo = req.Foo + } else { + tp.Foo = strings.Repeat(req.Foo, 2) + } if dl, ok := ctx.Deadline(); ok { tp.Deadline = dl.UnixNano() } @@ -329,38 +337,238 @@ func TestImmediateServerShutdown(t *testing.T) { } func TestOversizeCall(t *testing.T) { - var ( - ctx = context.Background() - server = mustServer(t)(NewServer()) - addr, listener = newTestListener(t) - errs = make(chan error, 1) - client, cleanup = newTestClient(t, addr) - ) - defer cleanup() - defer listener.Close() - go func() { - errs <- server.Serve(ctx, listener) - }() + type testCase struct { + name string + echoOnce bool + clientLimit int + serverLimit int + requestSize int + clientFail bool + sendFail bool + serverFail bool + } + + overhead := getWireMessageOverhead(t) + + clientOpts := func(tc *testCase) []ClientOpts { + if tc.clientLimit == 0 { + return nil + } + return []ClientOpts{WithClientWireMessageLimit(tc.clientLimit)} + } + serverOpts := func(tc *testCase) []ServerOpt { + if tc.serverLimit == 0 { + return nil + } + return []ServerOpt{WithServerWireMessageLimit(tc.serverLimit)} + } + + runTest := func(t *testing.T, tc *testCase) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer(serverOpts(tc)...)) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + client, cleanup = newTestClient(t, addr, clientOpts(tc)...) + ) + defer cleanup() + defer listener.Close() + go func() { + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{echoOnce: tc.echoOnce}) + + req := &internal.TestPayload{ + Foo: strings.Repeat("a", tc.requestSize), + } + rsp := &internal.TestPayload{} + + err := client.Call(ctx, serviceName, "Test", req, rsp) + if tc.clientFail { + if err == nil { + t.Fatalf("expected error from oversized message") + } else if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if tc.sendFail { + var msgLenErr *OversizedMessageErr + if !errors.As(err, &msgLenErr) { + t.Fatalf("failed to retrieve client send OversizedMessageErr") + } + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in client send oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in client send oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("client send oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } else if tc.serverFail { + if err == nil { + t.Fatalf("expected error from server-side oversized message") + } else { + if status, ok := status.FromError(err); !ok { + t.Fatalf("expected status present in error: %v", err) + } else if status.Code() != codes.ResourceExhausted { + t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) + } + if msgLenErr, ok := OversizedMessageFromError(err); !ok { + t.Fatalf("failed to retrieve oversized message error") + } else { + rejLen, maxLen := msgLenErr.RejectedLength(), msgLenErr.MaximumLength() + if rejLen == 0 { + t.Fatalf("zero rejected length in oversized message error") + } + if maxLen == 0 { + t.Fatalf("zero maximum length in oversized message error") + } + if rejLen <= maxLen { + t.Fatalf("oversized message error rejected < max. length (%d < %d)", + rejLen, maxLen) + } + } + } + } else { + if err != nil { + t.Fatalf("expected success, got error %v", err) + } + } - registerTestingService(server, &testingServer{}) + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + if err := <-errs; err != ErrServerClosed { + t.Fatal(err) + } + } - tp := &internal.TestPayload{ - Foo: strings.Repeat("a", 1+messageLengthMax), + for _, tc := range []*testCase{ + { + name: "default limits, fitting request and response", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + }, + { + name: "default limits, only recv side check", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit - overhead, + serverFail: true, + }, + + { + name: "default limits, oversized request", + echoOnce: true, + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit, + clientFail: true, + }, + { + name: "default limits, oversized response", + clientLimit: 0, + serverLimit: 0, + requestSize: DefaultMessageLengthLimit / 2, + serverFail: true, + }, + { + name: "8K limits, 4K request and response", + echoOnce: true, + clientLimit: 8 * 1024, + serverLimit: 8 * 1024, + requestSize: 4 * 1024, + }, + { + name: "4K limits, barely fitting cc. 4K request and response", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4*1024 - overhead, + }, + { + name: "4K limits, oversized request on client side", + echoOnce: true, + clientLimit: 4 * 1024, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + clientFail: true, + sendFail: true, + }, + { + name: "4K limits, oversized request on server side", + echoOnce: true, + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "4K limits, oversized response on client side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 8*1024 + overhead, + clientFail: true, + }, + { + name: "4K limits, oversized response on server side", + clientLimit: 4*1024 + overhead, + serverLimit: 4 * 1024, + requestSize: 4 * 1024, + serverFail: true, + }, + { + name: "too small limits, adjusted to minimum accepted limit", + echoOnce: true, + clientLimit: 4, + serverLimit: 4, + requestSize: 4*1024 - overhead, + }, + { + name: "maximum allowed protocol limit", + echoOnce: true, + clientLimit: MaxMessageLengthLimit, + serverLimit: MaxMessageLengthLimit, + requestSize: MaxMessageLengthLimit - overhead, + }, + } { + t.Run(tc.name, func(t *testing.T) { + runTest(t, tc) + }) } - if err := client.Call(ctx, serviceName, "Test", tp, tp); err == nil { - t.Fatalf("expected error from oversized message") - } else if status, ok := status.FromError(err); !ok { - t.Fatalf("expected status present in error: %v", err) - } else if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected code: %v != %v", status.Code(), codes.ResourceExhausted) +} + +func getWireMessageOverhead(t *testing.T) int { + emptyReq, err := codec{}.Marshal(&Request{ + Service: serviceName, + Method: "Test", + }) + if err != nil { + t.Fatalf("failed to marshal empty request: %v", err) } - if err := server.Shutdown(ctx); err != nil { - t.Fatal(err) + emptyRsp, err := codec{}.Marshal(&Response{ + Status: status.New(codes.OK, "").Proto(), + }) + if err != nil { + t.Fatalf("failed to marshal empty response: %v", err) } - if err := <-errs; err != ErrServerClosed { - t.Fatal(err) + + reqLen := len(emptyReq) + rspLen := len(emptyRsp) + if reqLen > rspLen { + return reqLen + messageHeaderLength } + + return rspLen + messageHeaderLength } func TestClientEOF(t *testing.T) { @@ -581,13 +789,20 @@ func newTestClient(t testing.TB, addr string, opts ...ClientOpts) (*Client, func } func newTestListener(t testing.TB) (string, net.Listener) { - var prefix string + var ( + name = t.Name() + prefix string + ) // Abstracts sockets are only available on Linux. if runtime.GOOS == "linux" { prefix = "\x00" + } else { + if split := strings.SplitN(name, "/", 2); len(split) == 2 { + name = split[0] + "-" + fmt.Sprintf("%x", md5.Sum([]byte(split[1]))) + } } - addr := prefix + t.Name() + addr := prefix + name listener, err := net.Listen("unix", addr) if err != nil { t.Fatal(err)