Skip to content

Commit

Permalink
add flag to identify status triggered by upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
DMwangnima committed Sep 16, 2024
1 parent f6df547 commit ed1b013
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 18 deletions.
5 changes: 4 additions & 1 deletion pkg/remote/trans/nphttp2/grpc/controlbuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,10 @@ func (c *controlBuffer) get(block bool) (interface{}, error) {
case <-c.ch:
case <-c.done:
c.finish(errConnClosing)
return nil, errConnClosing
c.mu.Lock()
err := c.err
c.mu.Unlock()
return nil, err
}
}
}
Expand Down
20 changes: 17 additions & 3 deletions pkg/remote/trans/nphttp2/grpc/controlbuf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ package grpc

import (
"context"
"errors"
"testing"
"time"

"github.com/cloudwego/kitex/internal/test"
)

func TestControlBuf(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
cb := newControlBuffer(ctx.Done())

// test put()
Expand Down Expand Up @@ -70,6 +71,19 @@ func TestControlBuf(t *testing.T) {

cb.throttle()

// test finish()
cb.finish(errConnClosing)
finishErr := errors.New("finish")
go func() {
// make sure cb.get blocks
time.Sleep(time.Millisecond * 100)
cb.finish(finishErr)
cancel()
}()
item, err = cb.get(true)
test.Assert(t, err == finishErr, err)
test.Assert(t, item == nil, item)

err = cb.put(testItem)
test.Assert(t, err == finishErr, err)
_, err = cb.get(false)
test.Assert(t, err == finishErr, err)
}
15 changes: 7 additions & 8 deletions pkg/remote/trans/nphttp2/grpc/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,12 @@ var (

// errors used for cancelling stream.
// the code should be codes.Canceled coz it's NOT returned from remote
errConnectionEOF = status.Err(codes.Canceled, "transport: connection EOF")
errConnectionEOF = status.ErrWithTriggeredByUpstream(codes.Canceled, "transport: connection EOF")
errMaxStreamsExceeded = status.Err(codes.Canceled, "transport: max streams exceeded")
errNotReachable = status.Err(codes.Canceled, "transport: server not reachable")
errMaxAgeClosing = status.Err(codes.Canceled, "transport: closing server transport due to maximum connection age")
errIdleClosing = status.Err(codes.Canceled, "transport: closing server transport due to idleness")
errMaxAgeClosing = status.ErrWithTriggeredByUpstream(codes.Canceled, "transport: closing server transport due to maximum connection age")
errIdleClosing = status.ErrWithTriggeredByUpstream(codes.Canceled, "transport: closing server transport due to idleness")

errRSTStreamRecv = status.Err(codes.Canceled, "transport: RSTStream Frame received")
errHeaderListSizeLimitViolation = status.Err(codes.Internal, ErrHeaderListSizeLimitViolation.Error())
errIllegalHeaderWrite = status.Err(codes.Internal, ErrIllegalHeaderWrite.Error())
)
Expand Down Expand Up @@ -409,7 +408,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
s := t.activeStreams[se.StreamID]
t.mu.Unlock()
if s != nil {
t.closeStream(s, status.Errorf(codes.Canceled, "transport: ReadFrame encountered http2.StreamError: %v", err), true, se.Code, false)
t.closeStream(s, status.ErrorfWithTriggeredByUpstream(codes.Canceled, "transport: ReadFrame encountered http2.StreamError: %v", err), true, se.Code, false)
} else {
t.controlBuf.put(&cleanupStream{
streamID: se.StreamID,
Expand All @@ -426,7 +425,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
return
}
klog.CtxWarnf(t.ctx, "transport: http2Server.HandleStreams failed to read frame: %v", err)
t.closeWithErr(err)
t.closeWithErr(status.ErrorfWithTriggeredByUpstream(codes.Canceled, "transport: ReadFrame encountered err: %v", err))
return
}
switch frame := frame.(type) {
Expand Down Expand Up @@ -553,7 +552,7 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) {
}
if size > 0 {
if err := s.fc.onData(size); err != nil {
t.closeStream(s, status.Errorf(codes.Canceled, "transport: inflow control err: %v", err), true, http2.ErrCodeFlowControl, false)
t.closeStream(s, status.ErrorfWithTriggeredByUpstream(codes.Canceled, "transport: inflow control err: %v", err), true, http2.ErrCodeFlowControl, false)
return
}
if f.Header().Flags.Has(http2.FlagDataPadded) {
Expand Down Expand Up @@ -581,7 +580,7 @@ func (t *http2Server) handleData(f *grpcframe.DataFrame) {
func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) {
// If the stream is not deleted from the transport's active streams map, then do a regular close stream.
if s, ok := t.getStream(f); ok {
t.closeStream(s, errRSTStreamRecv, false, 0, false)
t.closeStream(s, status.ErrorfWithTriggeredByUpstream(codes.Canceled, "transport: RSTStream Frame received with error code: %v", f.ErrCode), false, 0, false)
return
}
// If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map.
Expand Down
4 changes: 2 additions & 2 deletions pkg/remote/trans/nphttp2/grpc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ func (h *testStreamHandler) handleRstStreamFrame(t *testing.T, s *Stream) {
test.Assert(t, err == nil, err)
h.svr.waitClientFinished()
err = h.t.Write(s, nil, nil, nil)
test.Assert(t, err == errRSTStreamRecv, err)
test.Assert(t, strings.Contains(err.Error(), "transport: RSTStream Frame received"), err)
_, err = s.Read(readTestMsg)
test.Assert(t, err == errRSTStreamRecv, err)
test.Assert(t, strings.Contains(err.Error(), "transport: RSTStream Frame received"), err)
}

func (h *testStreamHandler) handleStreams(t *testing.T, s *Stream) {
Expand Down
39 changes: 37 additions & 2 deletions pkg/remote/trans/nphttp2/status/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,34 @@ type Iface interface {
// and should be created with New, Newf, or FromProto.
type Status struct {
s *spb.Status
// triggeredByUpstream identifies whether this error was triggered by upstream
triggeredByUpstream bool
}

// New returns a Status representing c and msg.
func New(c codes.Code, msg string) *Status {
return &Status{s: &spb.Status{Code: int32(c), Message: msg}}
}

// NewWithTriggeredByUpstream returns a Status identified triggered by upstream
func NewWithTriggeredByUpstream(c codes.Code, msg string) *Status {
res := New(c, msg)
res.triggeredByUpstream = true
return res
}

// Newf returns New(c, fmt.Sprintf(format, a...)).
func Newf(c codes.Code, format string, a ...interface{}) *Status {
return New(c, fmt.Sprintf(format, a...))
}

// NewfWithTriggeredByUpstream is the same as Newf with identifying triggered by upstream
func NewfWithTriggeredByUpstream(c codes.Code, format string, a ...interface{}) *Status {
res := Newf(c, format, a...)
res.triggeredByUpstream = true
return res
}

// ErrorProto returns an error representing s. If s.Code is OK, returns nil.
func ErrorProto(s *spb.Status) error {
return FromProto(s).Err()
Expand All @@ -76,11 +92,21 @@ func Err(c codes.Code, msg string) error {
return New(c, msg).Err()
}

// ErrWithTriggeredByUpstream is the same as Err with identifying triggered by upstream
func ErrWithTriggeredByUpstream(c codes.Code, msg string) error {
return NewWithTriggeredByUpstream(c, msg).Err()
}

// Errorf returns Error(c, fmt.Sprintf(format, a...)).
func Errorf(c codes.Code, format string, a ...interface{}) error {
return Err(c, fmt.Sprintf(format, a...))
}

// ErrorfWithTriggeredByUpstream is the same as Errorf with identifying triggered by upstream
func ErrorfWithTriggeredByUpstream(c codes.Code, format string, a ...interface{}) error {
return NewfWithTriggeredByUpstream(c, format, a...).Err()
}

// Code returns the status code contained in s.
func (s *Status) Code() codes.Code {
if s == nil || s.s == nil {
Expand Down Expand Up @@ -119,7 +145,7 @@ func (s *Status) Err() error {
if s.Code() == codes.OK {
return nil
}
return &Error{e: s.Proto()}
return &Error{e: s.Proto(), triggeredByUpstream: s.triggeredByUpstream}
}

// WithDetails returns a new status with the provided details messages appended to the status.
Expand Down Expand Up @@ -158,13 +184,22 @@ func (s *Status) Details() []interface{} {
return details
}

// TriggeredByUpstream returns whether the error was triggered by upstream.
func (s *Status) TriggeredByUpstream() bool {
return s.triggeredByUpstream
}

// Error wraps a pointer of a status proto. It implements error and Status,
// and a nil *Error should never be returned by this package.
type Error struct {
e *spb.Status
e *spb.Status
triggeredByUpstream bool
}

func (e *Error) Error() string {
if e.triggeredByUpstream {
return fmt.Sprintf("upstream error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage())
}
return fmt.Sprintf("rpc error: code = %d desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage())
}

Expand Down
31 changes: 29 additions & 2 deletions pkg/remote/trans/nphttp2/status/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package status
import (
"context"
"fmt"
"strings"
"testing"

spb "google.golang.org/genproto/googleapis/rpc/status"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestError(t *testing.T) {
s.Code = 1
s.Message = "test err"

er := &Error{s}
er := &Error{s, false}
test.Assert(t, len(er.Error()) > 0)

status := er.GRPCStatus()
Expand Down Expand Up @@ -101,7 +102,7 @@ func TestFromContextError(t *testing.T) {
s := new(spb.Status)
s.Code = 1
s.Message = "test err"
grpcErr := &Error{s}
grpcErr := &Error{s, false}
// grpc err
codeGrpcErr := Code(grpcErr)
test.Assert(t, codeGrpcErr == codes.Canceled)
Expand All @@ -114,3 +115,29 @@ func TestFromContextError(t *testing.T) {
codeNil := Code(nil)
test.Assert(t, codeNil == codes.OK)
}

func TestStatusWithTriggeredByUpstream(t *testing.T) {
statusMsg := "test"
statusOK := NewWithTriggeredByUpstream(codes.OK, statusMsg)
test.Assert(t, statusOK.Code() == codes.OK)
test.Assert(t, statusOK.Message() == statusMsg)
test.Assert(t, statusOK.triggeredByUpstream)
statusErr := statusOK.Err()
test.Assert(t, statusErr == nil)

statusCanceled := NewfWithTriggeredByUpstream(codes.Canceled, "%s", statusMsg)
test.Assert(t, statusCanceled.Code() == codes.Canceled)
test.Assert(t, statusCanceled.Message() == statusMsg)
test.Assert(t, statusCanceled.triggeredByUpstream)
statusErr = statusCanceled.Err()
test.Assert(t, strings.Contains(statusErr.Error(), "upstream error"))
}

func TestErrorWithTriggeredByUpstream(t *testing.T) {
statusMsg := "test"
errOK := ErrWithTriggeredByUpstream(codes.OK, statusMsg)
test.Assert(t, errOK == nil)

errCanceled := ErrorfWithTriggeredByUpstream(codes.Canceled, "%s", statusMsg)
test.Assert(t, strings.Contains(errCanceled.Error(), "upstream error"))
}

0 comments on commit ed1b013

Please sign in to comment.