diff --git a/compression.go b/compression.go index ee43412b..0a1db568 100644 --- a/compression.go +++ b/compression.go @@ -96,17 +96,13 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM } return errorf(CodeInvalidArgument, "decompress: %w", err) } - if readMaxBytes > 0 && bytesRead > readMaxBytes { - discardedBytes, err := io.Copy(io.Discard, decompressor) - _ = c.putDecompressor(decompressor) - if err != nil { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes) - } if err := c.putDecompressor(decompressor); err != nil { return errorf(CodeUnknown, "recycle decompressor: %w", err) } + if readMaxBytes > 0 && bytesRead > readMaxBytes { + // Resource is exhausted, fail fast without reading more data from the reader. + return errorf(CodeResourceExhausted, "decompressed message size is larger than configured max %d", readMaxBytes) + } return nil } diff --git a/connect_ext_test.go b/connect_ext_test.go index b93c5708..a783e85d 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -1197,7 +1197,6 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes))) }) t.Run("read_max_large", func(t *testing.T) { t.Parallel() @@ -1206,16 +1205,14 @@ func TestHandlerWithReadMaxBytes(t *testing.T) { } // Serializes to much larger than readMaxBytes (5 MiB) pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)} - expectedSize := proto.Size(pingRequest) // With gzip request compression, the error should indicate the envelope size (before decompression) is too large. if compressed { - expectedSize = gzipCompressedSize(t, pingRequest) + expectedSize := gzipCompressedSize(t, pingRequest) assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes)) } _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } newHTTP2Server := func(t *testing.T) *memhttp.Server { @@ -1378,7 +1375,6 @@ func TestClientWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes))) }) t.Run("read_max_large", func(t *testing.T) { t.Parallel() @@ -1397,7 +1393,6 @@ func TestClientWithReadMaxBytes(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest)) assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message")) assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted) - assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes)) }) } t.Run("connect", func(t *testing.T) { diff --git a/envelope.go b/envelope.go index bc85c551..ec296dee 100644 --- a/envelope.go +++ b/envelope.go @@ -228,9 +228,13 @@ type envelopeReader struct { compressionPool *compressionPool bufferPool *bufferPool readMaxBytes int + isEOF bool } func (r *envelopeReader) Unmarshal(message any) *Error { + if r.isEOF { + return NewError(CodeInternal, io.EOF) + } buffer := r.bufferPool.Get() var dontRelease *bytes.Buffer defer func() { @@ -240,25 +244,20 @@ func (r *envelopeReader) Unmarshal(message any) *Error { }() env := &envelope{Data: buffer} - err := r.Read(env) - switch { - case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil: + if err := r.Read(env); err != nil { + // Mark the reader as EOF so that subsequent reads return EOF. + r.isEOF = true + return err + } + if env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil { return errorf( CodeInternal, "protocol error: sent compressed message without compression support", ) - case err == nil && - (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && - env.Data.Len() == 0: + } else if (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && env.Data.Len() == 0 { // This is a standard message (because none of the top 7 bits are set) and // there's no data, so the zero value of the message is correct. return nil - case err != nil && errors.Is(err, io.EOF): - // The stream has ended. Propagate the EOF to the caller. - return err - case err != nil: - // Something's wrong. - return err } data := env.Data @@ -317,7 +316,7 @@ func (r *envelopeReader) Read(env *envelope) *Error { // The stream ended cleanly. That's expected, but we need to propagate an EOF // to the user so that they know that the stream has ended. We shouldn't // add any alarming text about protocol errors, though. - return NewError(CodeUnknown, err) + return NewError(CodeInternal, err) } err = wrapIfMaxBytesError(err, "read 5 byte message prefix") err = wrapIfContextDone(r.ctx, err) @@ -332,12 +331,8 @@ func (r *envelopeReader) Read(env *envelope) *Error { } size := int64(binary.BigEndian.Uint32(prefixes[1:5])) if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) { - n, err := io.CopyN(io.Discard, r.reader, size) - r.bytesRead += n - if err != nil && !errors.Is(err, io.EOF) { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes) + // Resource is exhausted, fail fast without reading more data from the stream. + return errorf(CodeResourceExhausted, "received message size %d is larger than configured max %d", size, r.readMaxBytes) } // We've read the prefix, so we know how many bytes to expect. // CopyN will return an error if it doesn't read the requested diff --git a/protocol.go b/protocol.go index 9add614c..dc8e3d06 100644 --- a/protocol.go +++ b/protocol.go @@ -287,12 +287,12 @@ func isCommaOrSpace(c rune) bool { } func discard(reader io.Reader) (int64, error) { - if lr, ok := reader.(*io.LimitedReader); ok { - return io.Copy(io.Discard, lr) - } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. - lr := &io.LimitedReader{R: reader, N: discardLimit} + lr, ok := reader.(*io.LimitedReader) + if !ok { + lr = &io.LimitedReader{R: reader, N: discardLimit} + } return io.Copy(io.Discard, lr) } diff --git a/protocol_connect.go b/protocol_connect.go index 6828ab4d..5f7cec94 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1088,8 +1088,8 @@ type connectUnaryUnmarshaler struct { codec Codec compressionPool *compressionPool bufferPool *bufferPool - alreadyRead bool readMaxBytes int + isEOF bool } func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { @@ -1097,10 +1097,10 @@ func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { } func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error { - if u.alreadyRead { + if u.isEOF { return NewError(CodeInternal, io.EOF) } - u.alreadyRead = true + u.isEOF = true data := u.bufferPool.Get() defer u.bufferPool.Put(data) reader := u.reader @@ -1118,12 +1118,8 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by return errorf(CodeUnknown, "read message: %w", err) } if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { - // Attempt to read to end in order to allow connection re-use - discardedBytes, err := io.Copy(io.Discard, u.reader) - if err != nil { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes) + // Resource is exhausted, fail fast without reading more data from the stream. + return errorf(CodeResourceExhausted, "message size is larger than configured max %d", u.readMaxBytes) } if data.Len() > 0 && u.compressionPool != nil { decompressed := u.bufferPool.Get() diff --git a/protocol_grpc.go b/protocol_grpc.go index e10ecad7..db1890e8 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -319,8 +319,8 @@ func (g *grpcClient) NewConn( } } else { conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { - // To access HTTP trailers, we need to read the body to EOF. - _, _ = discard(call) + // Caller must guarantee the body is read to EOF to access + // trailers. return call.ResponseTrailer() } }