Skip to content

Commit

Permalink
Merge pull request #1217 from ydb-platform/trace-stream-finish
Browse files Browse the repository at this point in the history
Trace stream finish
  • Loading branch information
asmyasnikov authored May 2, 2024
2 parents d6ceff7 + 36141fb commit 8962b78
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 101 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
* Added `trace.DriverConnStreamEvents` details bit
* Added `trace.Driver.OnConnStreamFinish` event

## v3.66.1
* Added flush messages from buffer before close topic writer
* Added Flush method for topic writer
Expand Down
55 changes: 25 additions & 30 deletions internal/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Conn interface {
type conn struct {
mtx sync.RWMutex
config Config // ro access
cc *grpc.ClientConn
grpcConn *grpc.ClientConn
done chan struct{}
endpoint endpoint.Endpoint // ro access
closed bool
Expand Down Expand Up @@ -121,7 +121,7 @@ func (c *conn) park(ctx context.Context) (err error) {
return nil
}

if c.cc == nil {
if c.grpcConn == nil {
return nil
}

Expand Down Expand Up @@ -161,7 +161,7 @@ func (c *conn) setState(ctx context.Context, s State) State {
func (c *conn) Unban(ctx context.Context) State {
var newState State
c.mtx.RLock()
cc := c.cc
cc := c.grpcConn
c.mtx.RUnlock()
if isAvailable(cc) {
newState = Online
Expand All @@ -186,8 +186,8 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) {
c.mtx.Lock()
defer c.mtx.Unlock()

if c.cc != nil {
return c.cc, nil
if c.grpcConn != nil {
return c.grpcConn, nil
}

if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 {
Expand Down Expand Up @@ -234,10 +234,10 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) {
)
}

c.cc = cc
c.grpcConn = cc
c.setState(ctx, Online)

return c.cc, nil
return c.grpcConn, nil
}

func (c *conn) onTransportError(ctx context.Context, cause error) {
Expand All @@ -252,11 +252,11 @@ func isAvailable(raw *grpc.ClientConn) bool {

// conn must be locked
func (c *conn) close(ctx context.Context) (err error) {
if c.cc == nil {
if c.grpcConn == nil {
return nil
}
err = c.cc.Close()
c.cc = nil
err = c.grpcConn.Close()
c.grpcConn = nil
c.setState(ctx, Offline)

return c.wrapError(err)
Expand Down Expand Up @@ -423,19 +423,23 @@ func (c *conn) NewStream(

ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID))

ctx, cancel := xcontext.WithCancel(ctx)
ctx, cancel := c.childStreams.WithCancel(ctx)
defer func() {
if finalErr != nil {
cancel()
} else {
c.childStreams.Remember(&cancel)
}
}()

s, err := cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(func(err error) {
cancel()
c.childStreams.Forget(&cancel)
}))...)
s := &grpcClientStream{
parentConn: c,
streamCtx: ctx,
streamCancel: cancel,
wrapping: useWrapping,
traceID: traceID,
sentMark: sentMark,
}

s.stream, err = cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(s.finish))...)
if err != nil {
if xerrors.IsContextError(err) {
return nil, xerrors.WithStackTrace(err)
Expand All @@ -451,25 +455,16 @@ func (c *conn) NewStream(
xerrors.WithTraceID(traceID),
)
if sentMark.canRetry() {
return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream")))
return nil, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream")))
}

return s, c.wrapError(err)
return nil, c.wrapError(err)
}

return s, err
return nil, err
}

return &grpcClientStream{
ClientStream: s,
c: c,
wrapping: useWrapping,
traceID: traceID,
sentMark: sentMark,
onDone: func(ctx context.Context, md metadata.MD) {
meta.CallTrailerCallback(ctx, md)
},
}, nil
return s, nil
}

func (c *conn) wrapError(err error) error {
Expand Down
81 changes: 49 additions & 32 deletions internal/conn/grpc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,50 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/wrap"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
)

type grpcClientStream struct {
grpc.ClientStream
c *conn
wrapping bool
traceID string
sentMark *modificationMark
onDone func(ctx context.Context, md metadata.MD)
parentConn *conn
stream grpc.ClientStream
streamCtx context.Context //nolint:containedctx
streamCancel context.CancelFunc
wrapping bool
traceID string
sentMark *modificationMark
}

func (s *grpcClientStream) Header() (metadata.MD, error) {
return s.stream.Header()
}

func (s *grpcClientStream) Trailer() metadata.MD {
return s.stream.Trailer()
}

func (s *grpcClientStream) Context() context.Context {
return s.stream.Context()
}

func (s *grpcClientStream) CloseSend() (err error) {
var (
ctx = s.Context()
onDone = trace.DriverOnConnStreamCloseSend(s.c.config.Trace(), &ctx,
ctx = s.streamCtx
onDone = trace.DriverOnConnStreamCloseSend(s.parentConn.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*grpcClientStream).CloseSend"),
)
)
defer func() {
onDone(err)
}()

stop := s.c.lastUsage.Start()
stop := s.parentConn.lastUsage.Start()
defer stop()

err = s.ClientStream.CloseSend()
err = s.stream.CloseSend()

if err != nil {
if xerrors.IsContextError(err) {
Expand All @@ -48,7 +62,7 @@ func (s *grpcClientStream) CloseSend() (err error) {
return s.wrapError(
xerrors.Transport(
err,
xerrors.WithAddress(s.c.Address()),
xerrors.WithAddress(s.parentConn.Address()),
xerrors.WithTraceID(s.traceID),
),
)
Expand All @@ -62,32 +76,32 @@ func (s *grpcClientStream) CloseSend() (err error) {

func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
var (
ctx = s.Context()
onDone = trace.DriverOnConnStreamSendMsg(s.c.config.Trace(), &ctx,
ctx = s.streamCtx
onDone = trace.DriverOnConnStreamSendMsg(s.parentConn.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*grpcClientStream).SendMsg"),
)
)
defer func() {
onDone(err)
}()

stop := s.c.lastUsage.Start()
stop := s.parentConn.lastUsage.Start()
defer stop()

err = s.ClientStream.SendMsg(m)
err = s.stream.SendMsg(m)

if err != nil {
if xerrors.IsContextError(err) {
return xerrors.WithStackTrace(err)
}

defer func() {
s.c.onTransportError(ctx, err)
s.parentConn.onTransportError(ctx, err)
}()

if s.wrapping {
err = xerrors.Transport(err,
xerrors.WithAddress(s.c.Address()),
xerrors.WithAddress(s.parentConn.Address()),
xerrors.WithTraceID(s.traceID),
)
if s.sentMark.canRetry() {
Expand All @@ -105,28 +119,31 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) {
return nil
}

func (s *grpcClientStream) finish(err error) {
s.streamCancel()
trace.DriverOnConnStreamFinish(s.parentConn.config.Trace(), s.streamCtx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*grpcClientStream).finish"), err,
)
}

func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
var (
ctx = s.Context()
onDone = trace.DriverOnConnStreamRecvMsg(s.c.config.Trace(), &ctx,
ctx = s.streamCtx
onDone = trace.DriverOnConnStreamRecvMsg(s.parentConn.config.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*grpcClientStream).RecvMsg"),
)
)
defer func() {
onDone(err)
}()

stop := s.c.lastUsage.Start()
defer stop()

defer func() {
if err != nil {
md := s.ClientStream.Trailer()
s.onDone(ctx, md)
meta.CallTrailerCallback(s.streamCtx, s.stream.Trailer())
}
}()

err = s.ClientStream.RecvMsg(m)
stop := s.parentConn.lastUsage.Start()
defer stop()

err = s.stream.RecvMsg(m)

if err != nil { //nolint:nestif
if xerrors.IsContextError(err) {
Expand All @@ -135,13 +152,13 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {

defer func() {
if !xerrors.Is(err, io.EOF) {
s.c.onTransportError(ctx, err)
s.parentConn.onTransportError(ctx, err)
}
}()

if s.wrapping {
err = xerrors.Transport(err,
xerrors.WithAddress(s.c.Address()),
xerrors.WithAddress(s.parentConn.Address()),
)
if s.sentMark.canRetry() {
return s.wrapError(xerrors.Retryable(err,
Expand All @@ -161,7 +178,7 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) {
return s.wrapError(
xerrors.Operation(
xerrors.FromOperation(operation),
xerrors.WithAddress(s.c.Address()),
xerrors.WithAddress(s.parentConn.Address()),
),
)
}
Expand All @@ -177,7 +194,7 @@ func (s *grpcClientStream) wrapError(err error) error {
}

return xerrors.WithStackTrace(
newConnError(s.c.endpoint.NodeID(), s.c.endpoint.Address(), err),
newConnError(s.parentConn.endpoint.NodeID(), s.parentConn.endpoint.Address(), err),
xerrors.WithSkipDepth(1),
)
}
26 changes: 15 additions & 11 deletions internal/xcontext/cancels_quard.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,31 @@ import (
"sync"
)

type CancelsGuard struct {
mu sync.Mutex
cancels map[*context.CancelFunc]struct{}
}
type (
CancelsGuard struct {
mu sync.Mutex
cancels map[*context.CancelFunc]struct{}
}
)

func NewCancelsGuard() *CancelsGuard {
return &CancelsGuard{
cancels: make(map[*context.CancelFunc]struct{}),
}
}

func (g *CancelsGuard) Remember(cancel *context.CancelFunc) {
func (g *CancelsGuard) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) {
g.mu.Lock()
defer g.mu.Unlock()
g.cancels[cancel] = struct{}{}
}
ctx, cancel := WithCancel(ctx)
g.cancels[&cancel] = struct{}{}

func (g *CancelsGuard) Forget(cancel *context.CancelFunc) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.cancels, cancel)
return ctx, func() {
cancel()
g.mu.Lock()
defer g.mu.Unlock()
delete(g.cancels, &cancel)
}
}

func (g *CancelsGuard) Cancel() {
Expand Down
13 changes: 6 additions & 7 deletions internal/xcontext/cancels_quard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@ import (

func TestCancelsGuard(t *testing.T) {
g := NewCancelsGuard()
ctx, cancel1 := context.WithCancel(context.Background())
g.Remember(&cancel1)
ctx, cancel1 := g.WithCancel(context.Background())
require.Len(t, g.cancels, 1)
g.Forget(&cancel1)
cancel1()
require.Error(t, ctx.Err())
require.Empty(t, g.cancels, 0)
cancel2 := context.CancelFunc(func() {
cancel1()
})
g.Remember(&cancel2)
ctx, _ = g.WithCancel(context.Background())
require.Len(t, g.cancels, 1)
ctx, _ = g.WithCancel(ctx)
require.Len(t, g.cancels, 2)
g.Cancel()
require.Error(t, ctx.Err())
}
Loading

0 comments on commit 8962b78

Please sign in to comment.