Enforce discard limits on readers #790

Open · wants to merge 1 commit into base: main
12 changes: 4 additions & 8 deletions compression.go
@@ -96,17 +96,13 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM
         }
         return errorf(CodeInvalidArgument, "decompress: %w", err)
     }
-    if readMaxBytes > 0 && bytesRead > readMaxBytes {
-        discardedBytes, err := io.Copy(io.Discard, decompressor)
-        _ = c.putDecompressor(decompressor)
-        if err != nil {
-            return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err)
-        }
-        return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes)
-    }
     if err := c.putDecompressor(decompressor); err != nil {
         return errorf(CodeUnknown, "recycle decompressor: %w", err)
     }
+    if readMaxBytes > 0 && bytesRead > readMaxBytes {
+        // Resource is exhausted, fail fast without reading more data from the reader.
+        return errorf(CodeResourceExhausted, "decompressed message size is larger than configured max %d", readMaxBytes)
+    }
     return nil
 }

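The cap these errors refer to is the value configured with connect's WithReadMaxBytes option. With this change, an oversized decompressed message is rejected as soon as the limit is crossed instead of being drained to report its exact size, which is also why the exact-size assertions are dropped from the tests below. A minimal sketch of the fail-fast pattern, assuming plain compress/gzip and an invented decompressCapped helper rather than the package's internal compression pool:

```go
package limits

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
)

// decompressCapped decompresses src into dst but reads at most readMaxBytes+1
// bytes, so an oversized message is detected without draining the rest of the
// compressed stream to learn its true size.
func decompressCapped(dst *bytes.Buffer, src io.Reader, readMaxBytes int64) error {
	decompressor, err := gzip.NewReader(src)
	if err != nil {
		return fmt.Errorf("decompress: %w", err)
	}
	defer decompressor.Close()

	reader := io.Reader(decompressor)
	if readMaxBytes > 0 {
		// One extra byte is enough to tell "over the limit" from "exactly at it".
		reader = io.LimitReader(decompressor, readMaxBytes+1)
	}
	bytesRead, err := dst.ReadFrom(reader)
	if err != nil {
		return fmt.Errorf("decompress: %w", err)
	}
	if readMaxBytes > 0 && bytesRead > readMaxBytes {
		// Fail fast: no io.Copy(io.Discard, decompressor) to measure the full message.
		return fmt.Errorf("decompressed message size is larger than configured max %d", readMaxBytes)
	}
	return nil
}
```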
7 changes: 1 addition & 6 deletions connect_ext_test.go
@@ -1197,7 +1197,6 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
         _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
         assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
         assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
-        assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
     })
     t.Run("read_max_large", func(t *testing.T) {
         t.Parallel()
@@ -1206,16 +1205,14 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
         }
         // Serializes to much larger than readMaxBytes (5 MiB)
         pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)}
-        expectedSize := proto.Size(pingRequest)
         // With gzip request compression, the error should indicate the envelope size (before decompression) is too large.
         if compressed {
-            expectedSize = gzipCompressedSize(t, pingRequest)
+            expectedSize := gzipCompressedSize(t, pingRequest)
             assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes))
         }
         _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
         assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
         assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
-        assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
     })
 }
 newHTTP2Server := func(t *testing.T) *memhttp.Server {
@@ -1378,7 +1375,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
         _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
         assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
         assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
-        assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
     })
     t.Run("read_max_large", func(t *testing.T) {
         t.Parallel()
@@ -1397,7 +1393,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
         _, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
         assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
         assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
-        assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
     })
 }
 t.Run("connect", func(t *testing.T) {
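The compressed-path assertion above still leans on the suite's gzipCompressedSize helper. A plausible sketch of such a helper, assuming it proto-marshals the message, gzips the bytes, and returns the compressed length (the real helper in connect_ext_test.go may differ in detail):

```go
package limits

import (
	"bytes"
	"compress/gzip"
	"testing"

	"google.golang.org/protobuf/proto"
)

// gzipCompressedSize reports how large a message is on the wire when the
// client sends it with gzip compression enabled.
func gzipCompressedSize(tb testing.TB, message proto.Message) int {
	tb.Helper()
	uncompressed, err := proto.Marshal(message)
	if err != nil {
		tb.Fatal(err)
	}
	var compressed bytes.Buffer
	writer := gzip.NewWriter(&compressed)
	if _, err := writer.Write(uncompressed); err != nil {
		tb.Fatal(err)
	}
	if err := writer.Close(); err != nil {
		tb.Fatal(err)
	}
	return compressed.Len()
}
```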
33 changes: 14 additions & 19 deletions envelope.go
@@ -228,9 +228,13 @@ type envelopeReader struct {
     compressionPool *compressionPool
     bufferPool      *bufferPool
     readMaxBytes    int
+    isEOF           bool
 }

 func (r *envelopeReader) Unmarshal(message any) *Error {
+    if r.isEOF {
+        return NewError(CodeInternal, io.EOF)
+    }
     buffer := r.bufferPool.Get()
     var dontRelease *bytes.Buffer
     defer func() {
@@ -240,25 +244,20 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
     }()

     env := &envelope{Data: buffer}
-    err := r.Read(env)
-    switch {
-    case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
+    if err := r.Read(env); err != nil {
+        // Mark the reader as EOF so that subsequent reads return EOF.
+        r.isEOF = true
+        return err
+    }
+    if env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil {
         return errorf(
             CodeInternal,
             "protocol error: sent compressed message without compression support",
         )
-    case err == nil &&
-        (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
-        env.Data.Len() == 0:
+    } else if (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && env.Data.Len() == 0 {
         // This is a standard message (because none of the top 7 bits are set) and
         // there's no data, so the zero value of the message is correct.
         return nil
-    case err != nil && errors.Is(err, io.EOF):
-        // The stream has ended. Propagate the EOF to the caller.
-        return err
-    case err != nil:
-        // Something's wrong.
-        return err
     }

     data := env.Data
@@ -317,7 +316,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
         // The stream ended cleanly. That's expected, but we need to propagate an EOF
         // to the user so that they know that the stream has ended. We shouldn't
         // add any alarming text about protocol errors, though.
-        return NewError(CodeUnknown, err)
+        return NewError(CodeInternal, err)
     }
     err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
     err = wrapIfContextDone(r.ctx, err)
@@ -332,12 +331,8 @@
     }
     size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
     if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
-        n, err := io.CopyN(io.Discard, r.reader, size)
-        r.bytesRead += n
-        if err != nil && !errors.Is(err, io.EOF) {
-            return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
-        }
-        return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes)
+        // Resource is exhausted, fail fast without reading more data from the stream.
+        return errorf(CodeResourceExhausted, "received message size %d is larger than configured max %d", size, r.readMaxBytes)
     }
     // We've read the prefix, so we know how many bytes to expect.
     // CopyN will return an error if it doesn't read the requested
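The fail-fast check in envelopeReader.Read is possible because every enveloped message starts with a 5-byte prefix: one flags byte followed by a big-endian uint32 payload length, so the reader knows the message size before touching the payload. A self-contained sketch of that check, with illustrative names in place of the package's internal types:

```go
package limits

import (
	"encoding/binary"
	"fmt"
	"io"
)

// readEnvelopePrefix reads the 1-byte flags and 4-byte length that precede an
// enveloped message, rejecting oversized payloads before reading any of them.
func readEnvelopePrefix(r io.Reader, readMaxBytes int) (flags byte, size int64, err error) {
	var prefix [5]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		return 0, 0, fmt.Errorf("read 5 byte message prefix: %w", err)
	}
	flags = prefix[0]
	size = int64(binary.BigEndian.Uint32(prefix[1:5]))
	if readMaxBytes > 0 && size > int64(readMaxBytes) {
		// Fail fast: the payload is neither read nor discarded.
		return flags, size, fmt.Errorf("received message size %d is larger than configured max %d", size, readMaxBytes)
	}
	return flags, size, nil
}
```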
8 changes: 4 additions & 4 deletions protocol.go
@@ -287,12 +287,12 @@ func isCommaOrSpace(c rune) bool {
 }

 func discard(reader io.Reader) (int64, error) {
-    if lr, ok := reader.(*io.LimitedReader); ok {
-        return io.Copy(io.Discard, lr)
-    }
     // We don't want to get stuck throwing data away forever, so limit how much
     // we're willing to do here.
-    lr := &io.LimitedReader{R: reader, N: discardLimit}
+    lr, ok := reader.(*io.LimitedReader)
+    if !ok {
+        lr = &io.LimitedReader{R: reader, N: discardLimit}
+    }
     return io.Copy(io.Discard, lr)
 }
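The rewritten discard keeps a single io.Copy path: a caller that already passes an *io.LimitedReader keeps its own (possibly smaller) budget, and any other reader is capped at discardLimit so the helper can never drain a stream forever. A usage sketch, with an illustrative discardLimit value rather than the package's real constant:

```go
package limits

import (
	"io"
	"strings"
)

const discardLimit = 4 << 20 // illustrative only; the package defines its own limit

func boundedDiscard(reader io.Reader) (int64, error) {
	lr, ok := reader.(*io.LimitedReader)
	if !ok {
		lr = &io.LimitedReader{R: reader, N: discardLimit}
	}
	return io.Copy(io.Discard, lr)
}

func example() {
	// An already-limited reader keeps its smaller budget: at most 16 bytes are
	// discarded here even though discardLimit is far larger.
	limited := &io.LimitedReader{R: strings.NewReader(strings.Repeat("x", 1024)), N: 16}
	n, _ := boundedDiscard(limited)
	_ = n // n == 16
}
```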
14 changes: 5 additions & 9 deletions protocol_connect.go
@@ -1088,19 +1088,19 @@ type connectUnaryUnmarshaler struct {
     codec           Codec
     compressionPool *compressionPool
     bufferPool      *bufferPool
-    alreadyRead     bool
     readMaxBytes    int
+    isEOF           bool
 }

 func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error {
     return u.UnmarshalFunc(message, u.codec.Unmarshal)
 }

 func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error {
-    if u.alreadyRead {
+    if u.isEOF {
         return NewError(CodeInternal, io.EOF)
     }
-    u.alreadyRead = true
+    u.isEOF = true
     data := u.bufferPool.Get()
     defer u.bufferPool.Put(data)
     reader := u.reader
@@ -1118,12 +1118,8 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error {
         return errorf(CodeUnknown, "read message: %w", err)
     }
     if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
-        // Attempt to read to end in order to allow connection re-use
-        discardedBytes, err := io.Copy(io.Discard, u.reader)
-        if err != nil {
-            return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err)
-        }
-        return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes)
+        // Resource is exhausted, fail fast without reading more data from the stream.
+        return errorf(CodeResourceExhausted, "message size is larger than configured max %d", u.readMaxBytes)
     }
     if data.Len() > 0 && u.compressionPool != nil {
         decompressed := u.bufferPool.Get()
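Renaming alreadyRead to isEOF does not change the guard's behavior: a Connect unary body is unmarshaled at most once, and every later call reports io.EOF instead of rereading the body. A stripped-down sketch of that single-use pattern, with illustrative names:

```go
package limits

import "io"

// onceUnmarshaler consumes its reader at most once; subsequent calls surface
// io.EOF instead of touching the already-consumed body.
type onceUnmarshaler struct {
	reader io.Reader
	isEOF  bool
}

func (u *onceUnmarshaler) Unmarshal(decode func(io.Reader) error) error {
	if u.isEOF {
		return io.EOF
	}
	u.isEOF = true
	return decode(u.reader)
}
```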
4 changes: 2 additions & 2 deletions protocol_grpc.go
@@ -319,8 +319,8 @@ func (g *grpcClient) NewConn(
         }
     } else {
         conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header {
-            // To access HTTP trailers, we need to read the body to EOF.
-            _, _ = discard(call)
+            // Caller must guarantee the body is read to EOF to access
+            // trailers.
             return call.ResponseTrailer()
         }
     }