diff --git a/client/v3/concurrency/session.go b/client/v3/concurrency/session.go index 8838b77e2d7..2275e96c972 100644 --- a/client/v3/concurrency/session.go +++ b/client/v3/concurrency/session.go @@ -32,6 +32,7 @@ type Session struct { opts *sessionOptions id v3.LeaseID + ctx context.Context cancel context.CancelFunc donec <-chan struct{} } @@ -61,11 +62,14 @@ func NewSession(client *v3.Client, opts ...SessionOption) (*Session, error) { } donec := make(chan struct{}) - s := &Session{client: client, opts: ops, id: id, cancel: cancel, donec: donec} + s := &Session{client: client, opts: ops, id: id, ctx: ctx, cancel: cancel, donec: donec} // keep the lease alive until client error or cancelled context go func() { - defer close(donec) + defer func() { + close(donec) + cancel() + }() for range keepAlive { // eat messages until keep alive channel closes } @@ -82,6 +86,12 @@ func (s *Session) Client() *v3.Client { // Lease is the lease ID for keys bound to the session. func (s *Session) Lease() v3.LeaseID { return s.id } +// Ctx is the context attached to the session, it is canceled when the lease is orphaned, expires, or +// is otherwise no longer being refreshed. +func (s *Session) Ctx() context.Context { + return s.ctx +} + // Done returns a channel that closes when the lease is orphaned, expires, or // is otherwise no longer being refreshed. func (s *Session) Done() <-chan struct{} { return s.donec } diff --git a/tests/integration/clientv3/concurrency/session_test.go b/tests/integration/clientv3/concurrency/session_test.go index d1ca413200d..b1799117975 100644 --- a/tests/integration/clientv3/concurrency/session_test.go +++ b/tests/integration/clientv3/concurrency/session_test.go @@ -82,3 +82,32 @@ func TestSessionTTLOptions(t *testing.T) { } } + +func TestSessionCtx(t *testing.T) { + cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: exampleEndpoints()}) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + lease, err := cli.Grant(context.Background(), 100) + if err != nil { + t.Fatal(err) + } + s, err := concurrency.NewSession(cli, concurrency.WithLease(lease.ID)) + if err != nil { + t.Fatal(err) + } + defer s.Close() + assert.Equal(t, s.Lease(), lease.ID) + + childCtx, cancel := context.WithCancel(s.Ctx()) + defer cancel() + + go s.Orphan() + select { + case <-childCtx.Done(): + case <-time.After(time.Millisecond * 100): + t.Fatal("child context of session context is not canceled") + } + assert.Equal(t, childCtx.Err(), context.Canceled) +}