From f0f4fd4102ce58743536d899dacf527af3e67709 Mon Sep 17 00:00:00 2001 From: Aleksey Myasnikov Date: Fri, 29 Mar 2024 19:44:55 +0300 Subject: [PATCH] Refactoring --- balancers/balancers.go | 10 +- balancers/balancers_test.go | 46 +- balancers/config_test.go | 8 +- driver.go | 2 +- internal/balancer/balancer.go | 66 ++- internal/balancer/balancer_test.go | 38 +- internal/balancer/config/routerconfig.go | 17 +- internal/balancer/connections_state.go | 101 ++-- internal/balancer/connections_state_test.go | 373 +++++++------- internal/balancer/errors.go | 9 + internal/conn/cancels_quard.go | 40 ++ internal/conn/cancels_quard_test.go | 24 + internal/conn/cc_guard.go | 116 +++++ internal/conn/conn.go | 464 ++++++++---------- internal/conn/grpc_client_stream.go | 20 +- internal/conn/in_use_quard.go | 35 ++ internal/conn/in_use_quard_test.go | 52 ++ internal/conn/last_usage.go | 47 -- internal/conn/last_usage_guard.go | 50 ++ ...usage_test.go => last_usage_guard_test.go} | 6 +- internal/conn/pool.go | 82 ++-- internal/conn/state.go | 11 +- internal/endpoint/endpoint.go | 3 +- internal/mock/conn.go | 125 ----- internal/mock/conn_info.go | 28 ++ internal/mock/endpoint.go | 54 ++ internal/pool/pool.go | 24 +- internal/table/client.go | 4 - internal/xerrors/pessimized_error_test.go | 106 ---- internal/xerrors/transport.go | 2 +- internal/xerrors/transport_test.go | 97 ++++ log/driver.go | 4 +- options.go | 2 +- tests/integration/discovery_test.go | 41 +- tests/slo/prometheus.yml | 7 + trace/driver.go | 34 +- trace/driver_gtrace.go | 148 +++++- 37 files changed, 1316 insertions(+), 980 deletions(-) create mode 100644 internal/balancer/errors.go create mode 100644 internal/conn/cancels_quard.go create mode 100644 internal/conn/cancels_quard_test.go create mode 100644 internal/conn/cc_guard.go create mode 100644 internal/conn/in_use_quard.go create mode 100644 internal/conn/in_use_quard_test.go delete mode 100644 internal/conn/last_usage.go create mode 100644 
internal/conn/last_usage_guard.go rename internal/conn/{last_usage_test.go => last_usage_guard_test.go} (96%) delete mode 100644 internal/mock/conn.go create mode 100644 internal/mock/conn_info.go create mode 100644 internal/mock/endpoint.go delete mode 100644 internal/xerrors/pessimized_error_test.go create mode 100644 tests/slo/prometheus.yml diff --git a/balancers/balancers.go b/balancers/balancers.go index c73c8f503..8df46f729 100644 --- a/balancers/balancers.go +++ b/balancers/balancers.go @@ -26,7 +26,7 @@ func SingleConn() *balancerConfig.Config { type filterLocalDC struct{} -func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Conn) bool { +func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Info) bool { return c.Endpoint().Location() == info.SelfLocation } @@ -56,7 +56,7 @@ func PreferLocalDCWithFallBack(balancer *balancerConfig.Config) *balancerConfig. type filterLocations []string -func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Conn) bool { +func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Info) bool { location := strings.ToUpper(c.Endpoint().Location()) for _, l := range locations { if location == l { @@ -118,9 +118,9 @@ type Endpoint interface { LocalDC() bool } -type filterFunc func(info balancerConfig.Info, c conn.Conn) bool +type filterFunc func(info balancerConfig.Info, c conn.Info) bool -func (p filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool { +func (p filterFunc) Allow(info balancerConfig.Info, c conn.Info) bool { return p(info, c) } @@ -131,7 +131,7 @@ func (p filterFunc) String() string { // Prefer creates balancer which use endpoints by filter // Balancer "balancer" defines balancing algorithm between endpoints selected with filter func Prefer(balancer *balancerConfig.Config, filter func(endpoint Endpoint) bool) *balancerConfig.Config { - balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { + balancer.Filter = filterFunc(func(_ balancerConfig.Info, 
c conn.Info) bool { return filter(c.Endpoint()) }) diff --git a/balancers/balancers_test.go b/balancers/balancers_test.go index 75d0758b0..f7655f73d 100644 --- a/balancers/balancers_test.go +++ b/balancers/balancers_test.go @@ -11,56 +11,56 @@ import ( ) func TestPreferLocalDC(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", LocationField: "1"}, - &mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"}, - &mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"}, + &mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"}, } rr := PreferLocalDC(RandomChoice()) require.False(t, rr.AllowFallback) - require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns)) + require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns)) } func TestPreferLocalDCWithFallBack(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", LocationField: "1"}, - &mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"}, - &mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"}, + &mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"}, } rr := PreferLocalDCWithFallBack(RandomChoice()) require.True(t, rr.AllowFallback) - require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns)) + require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns)) } func 
TestPreferLocations(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"}, - &mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"}, + &mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"}, } rr := PreferLocations(RandomChoice(), "zero", "two") require.False(t, rr.AllowFallback) - require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns)) + require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns)) } func TestPreferLocationsWithFallback(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"}, - &mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"}, + &mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"}, } rr := PreferLocationsWithFallback(RandomChoice(), "zero", "two") require.True(t, rr.AllowFallback) - require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns)) + require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns)) } -func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, conns []conn.Conn) []conn.Conn { +func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, 
conns []conn.Info) []conn.Info { if b.Filter == nil { - b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { return true }) + b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return true }) } - res := make([]conn.Conn, 0, len(conns)) + res := make([]conn.Info, 0, len(conns)) for _, c := range conns { if b.Filter.Allow(info, c) { res = append(res, c) diff --git a/balancers/config_test.go b/balancers/config_test.go index 943e3f123..e153d4aa9 100644 --- a/balancers/config_test.go +++ b/balancers/config_test.go @@ -71,7 +71,7 @@ func TestFromConfig(t *testing.T) { }`, res: balancerConfig.Config{ DetectLocalDC: true, - Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool { // some non nil func return false }), @@ -95,7 +95,7 @@ func TestFromConfig(t *testing.T) { res: balancerConfig.Config{ AllowFallback: true, DetectLocalDC: true, - Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool { // some non nil func return false }), @@ -109,7 +109,7 @@ func TestFromConfig(t *testing.T) { "locations": ["AAA", "BBB", "CCC"] }`, res: balancerConfig.Config{ - Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool { // some non nil func return false }), @@ -125,7 +125,7 @@ func TestFromConfig(t *testing.T) { }`, res: balancerConfig.Config{ AllowFallback: true, - Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool { // some non nil func return false }), diff --git a/driver.go b/driver.go index 7adfab361..5d4a60618 100644 --- a/driver.go +++ b/driver.go @@ -152,7 +152,7 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) { d.query.Close, d.topic.Close, d.balancer.Close, - d.pool.Release, + 
d.pool.Detach, ) var issues []error diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index f69ec11e2..54d426fb1 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -41,7 +41,9 @@ type Balancer struct { localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) mu xsync.RWMutex - connectionsState *connectionsState + connectionsState *connectionsState[conn.Conn] + + closed chan struct{} onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) } @@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { return nil } -func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) ( +func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Info) ( nodes []trace.EndpointInfo, added []trace.EndpointInfo, dropped []trace.EndpointInfo, @@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), b.config.DetectLocalDC, ) - previousConns []conn.Conn + previousConns []conn.Info ) defer func() { nodes, added, dropped := endpointsDiff(endpoints, previousConns) @@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end connections := endpointsToConnections(b.pool, endpoints) for _, c := range connections { - b.pool.Allow(ctx, c) + if c.State() == conn.Banned { + b.pool.Unban(ctx, c) + } c.Endpoint().Touch() } @@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end b.mu.WithLock(func() { if b.connectionsState != nil { - previousConns = b.connectionsState.all + previousConns = make([]conn.Info, len(b.connectionsState.all)) + for i := range b.connectionsState.all { + previousConns[i] = b.connectionsState.all[i] + } } b.connectionsState = state for _, 
onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints { @@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end } func (b *Balancer) Close(ctx context.Context) (err error) { + close(b.closed) + onDone := trace.DriverOnBalancerClose( b.driverConfig.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"), @@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) { b.discoveryRepeater.Stop() } + b.applyDiscoveredEndpoints(ctx, nil, "") + if err = b.discoveryClient.Close(ctx); err != nil { return xerrors.WithStackTrace(err) } @@ -258,6 +269,7 @@ func New( driverConfig: driverConfig, pool: pool, localDCDetector: detectLocalDC, + closed: make(chan struct{}), } d := internalDiscovery.New(ctx, pool.Get( endpoint.New(driverConfig.Endpoint()), @@ -300,9 +312,14 @@ func (b *Balancer) Invoke( reply interface{}, opts ...grpc.CallOption, ) error { - return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { - return cc.Invoke(ctx, method, args, reply, opts...) - }) + select { + case <-b.closed: + return xerrors.WithStackTrace(errBalancerClosed) + default: + return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { + return cc.Invoke(ctx, method, args, reply, opts...) + }) + } } func (b *Balancer) NewStream( @@ -311,17 +328,22 @@ func (b *Balancer) NewStream( method string, opts ...grpc.CallOption, ) (_ grpc.ClientStream, err error) { - var client grpc.ClientStream - err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { - client, err = cc.NewStream(ctx, desc, method, opts...) + select { + case <-b.closed: + return nil, xerrors.WithStackTrace(errBalancerClosed) + default: + var client grpc.ClientStream + err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { + client, err = cc.NewStream(ctx, desc, method, opts...) 
+ + return err + }) + if err == nil { + return client, nil + } - return err - }) - if err == nil { - return client, nil + return nil, err } - - return nil, err } func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) { @@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc defer func() { if err == nil { - if cc.GetState() == conn.Banned { - b.pool.Allow(ctx, cc) - } - } else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) { + b.pool.Unban(ctx, cc) + } else if xerrors.MustBanConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) { b.pool.Ban(ctx, cc, err) } }() @@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc return nil } -func (b *Balancer) connections() *connectionsState { +func (b *Balancer) connections() *connectionsState[conn.Conn] { b.mu.RLock() defer b.mu.RUnlock() @@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { c, failedCount = state.GetConnection(ctx) if c == nil { return nil, xerrors.WithStackTrace( - fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), + fmt.Errorf("cannot get connection from Balancer after %d attempts: %w", failedCount, ErrNoEndpoints), ) } diff --git a/internal/balancer/balancer_test.go b/internal/balancer/balancer_test.go index 356952f38..b1308af45 100644 --- a/internal/balancer/balancer_test.go +++ b/internal/balancer/balancer_test.go @@ -15,7 +15,7 @@ import ( func TestEndpointsDiff(t *testing.T) { for _, tt := range []struct { newestEndpoints []endpoint.Endpoint - previousConns []conn.Conn + previousConns []conn.Info nodes []trace.EndpointInfo added []trace.EndpointInfo dropped []trace.EndpointInfo @@ -27,11 +27,11 @@ func TestEndpointsDiff(t *testing.T) { &mock.Endpoint{AddrField: "2"}, &mock.Endpoint{AddrField: "0"}, }, - 
previousConns: []conn.Conn{ - &mock.Conn{AddrField: "2"}, - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, + previousConns: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "2"}, + &mock.ConnInfo{EndpointAddrField: "1"}, + &mock.ConnInfo{EndpointAddrField: "0"}, + &mock.ConnInfo{EndpointAddrField: "3"}, }, nodes: []trace.EndpointInfo{ &mock.Endpoint{AddrField: "0"}, @@ -49,10 +49,10 @@ func TestEndpointsDiff(t *testing.T) { &mock.Endpoint{AddrField: "2"}, &mock.Endpoint{AddrField: "0"}, }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, + previousConns: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1"}, + &mock.ConnInfo{EndpointAddrField: "0"}, + &mock.ConnInfo{EndpointAddrField: "3"}, }, nodes: []trace.EndpointInfo{ &mock.Endpoint{AddrField: "0"}, @@ -71,11 +71,11 @@ func TestEndpointsDiff(t *testing.T) { &mock.Endpoint{AddrField: "3"}, &mock.Endpoint{AddrField: "0"}, }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - &mock.Conn{AddrField: "0"}, - &mock.Conn{AddrField: "3"}, + previousConns: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1"}, + &mock.ConnInfo{EndpointAddrField: "2"}, + &mock.ConnInfo{EndpointAddrField: "0"}, + &mock.ConnInfo{EndpointAddrField: "3"}, }, nodes: []trace.EndpointInfo{ &mock.Endpoint{AddrField: "0"}, @@ -93,10 +93,10 @@ func TestEndpointsDiff(t *testing.T) { &mock.Endpoint{AddrField: "3"}, &mock.Endpoint{AddrField: "0"}, }, - previousConns: []conn.Conn{ - &mock.Conn{AddrField: "4"}, - &mock.Conn{AddrField: "7"}, - &mock.Conn{AddrField: "8"}, + previousConns: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "4"}, + &mock.ConnInfo{EndpointAddrField: "7"}, + &mock.ConnInfo{EndpointAddrField: "8"}, }, nodes: []trace.EndpointInfo{ &mock.Endpoint{AddrField: "0"}, diff --git a/internal/balancer/config/routerconfig.go b/internal/balancer/config/routerconfig.go index 
0d1eb6703..a31c2f00e 100644 --- a/internal/balancer/config/routerconfig.go +++ b/internal/balancer/config/routerconfig.go @@ -42,11 +42,12 @@ func (c Config) String() string { return buffer.String() } -type Info struct { - SelfLocation string -} - -type Filter interface { - Allow(info Info, c conn.Conn) bool - String() string -} +type ( + Info struct { + SelfLocation string + } + Filter interface { + Allow(info Info, c conn.Info) bool + String() string + } +) diff --git a/internal/balancer/connections_state.go b/internal/balancer/connections_state.go index e9196ead7..a4acebcc0 100644 --- a/internal/balancer/connections_state.go +++ b/internal/balancer/connections_state.go @@ -8,23 +8,23 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" ) -type connectionsState struct { - connByNodeID map[uint32]conn.Conn +type connectionsState[T conn.Info] struct { + connByNodeID map[uint32]T - prefer []conn.Conn - fallback []conn.Conn - all []conn.Conn + prefer []T + fallback []T + all []T rand xrand.Rand } -func newConnectionsState( - conns []conn.Conn, +func newConnectionsState[T conn.Info]( + conns []T, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, -) *connectionsState { - res := &connectionsState{ +) *connectionsState[T] { + res := &connectionsState[T]{ connByNodeID: connsToNodeIDMap(conns), rand: xrand.New(xrand.WithLock()), } @@ -39,60 +39,63 @@ func newConnectionsState( return res } -func (s *connectionsState) PreferredCount() int { +func (s *connectionsState[T]) PreferredCount() int { return len(s.prefer) } -func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) { +func (s *connectionsState[T]) GetConnection(ctx context.Context) (nilConn T, failedCount int) { if err := ctx.Err(); err != nil { - return nil, 0 + return nilConn, 0 } - if c := s.preferConnection(ctx); c != nil { - return c, 0 + if cc, has := s.preferConnection(ctx); has { + return cc, 0 } - try := func(conns []conn.Conn) 
conn.Conn { - c, tryFailed := s.selectRandomConnection(conns, false) - failedCount += tryFailed - - return c + cc, tryCount, has := selectRandomConnection(s.rand, s.prefer, false) + failedCount += tryCount + if has { + return cc, failedCount } - if c := try(s.prefer); c != nil { - return c, failedCount + cc, tryCount, has = selectRandomConnection(s.rand, s.fallback, false) + failedCount += tryCount + if has { + return cc, failedCount } - if c := try(s.fallback); c != nil { - return c, failedCount + cc, tryCount, has = selectRandomConnection(s.rand, s.all, true) /* last resort: banned allowed */ + failedCount += tryCount + if has { + return cc, failedCount } - c, _ := s.selectRandomConnection(s.all, true) - - return c, failedCount + return nilConn, failedCount } -func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn { +func (s *connectionsState[T]) preferConnection(ctx context.Context) (nilConn T, has bool) { if e, hasPreferEndpoint := ContextEndpoint(ctx); hasPreferEndpoint { - c := s.connByNodeID[e.NodeID()] - if c != nil && isOkConnection(c, true) { - return c + cc, ok := s.connByNodeID[e.NodeID()] + if ok && isOkConnection(cc, true) { + return cc, true } } - return nil + return nilConn, false } -func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) { +func selectRandomConnection[T conn.Info]( + r xrand.Rand, conns []T, allowBanned bool, +) (nilConn T, failedConns int, has bool) { connCount := len(conns) if connCount == 0 { // return for empty list need for prevent panic in fast path - return nil, 0 + return nilConn, 0, false } // fast path - if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) { - return c, 0 + if cc := conns[r.Int(connCount)]; isOkConnection(cc, allowBanned) { + return cc, 0, true } // shuffled indexes slices need for guarantee about every connection will check @@ -100,26 +103,26 @@ func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned for index := range 
indexes { indexes[index] = index } - s.rand.Shuffle(connCount, func(i, j int) { + r.Shuffle(connCount, func(i, j int) { indexes[i], indexes[j] = indexes[j], indexes[i] }) for _, index := range indexes { - c := conns[index] - if isOkConnection(c, allowBanned) { - return c, 0 + cc := conns[index] + if isOkConnection(cc, allowBanned) { + return cc, 0, true } failedConns++ } - return nil, failedConns + return nilConn, failedConns, false } -func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { +func connsToNodeIDMap[T conn.Info](conns []T) (nodes map[uint32]T) { if len(conns) == 0 { return nil } - nodes = make(map[uint32]conn.Conn, len(conns)) + nodes = make(map[uint32]T, len(conns)) for _, c := range conns { nodes[c.Endpoint().NodeID()] = c } @@ -127,19 +130,19 @@ func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { return nodes } -func sortPreferConnections( - conns []conn.Conn, +func sortPreferConnections[T conn.Info]( + conns []T, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, -) (prefer, fallback []conn.Conn) { +) (prefer, fallback []T) { if filter == nil { return conns, nil } - prefer = make([]conn.Conn, 0, len(conns)) + prefer = make([]T, 0, len(conns)) if allowFallback { - fallback = make([]conn.Conn, 0, len(conns)) + fallback = make([]T, 0, len(conns)) } for _, c := range conns { @@ -153,8 +156,8 @@ func sortPreferConnections( return prefer, fallback } -func isOkConnection(c conn.Conn, bannedIsOk bool) bool { - switch c.GetState() { +func isOkConnection[T conn.Info](c T, bannedIsOk bool) bool { + switch c.State() { case conn.Online, conn.Created, conn.Offline: return true case conn.Banned: diff --git a/internal/balancer/connections_state_test.go b/internal/balancer/connections_state_test.go index b052b3933..6ea627e0b 100644 --- a/internal/balancer/connections_state_test.go +++ b/internal/balancer/connections_state_test.go @@ -10,13 +10,14 @@ import ( balancerConfig 
"github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand" ) func TestConnsToNodeIDMap(t *testing.T) { table := []struct { name string - source []conn.Conn - res map[uint32]conn.Conn + source []conn.Info + res map[uint32]conn.Info }{ { name: "Empty", @@ -25,35 +26,35 @@ func TestConnsToNodeIDMap(t *testing.T) { }, { name: "Zero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 0}, + source: []conn.Info{ + &mock.ConnInfo{EndpointNodeIDField: 0}, }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, + res: map[uint32]conn.Info{ + 0: &mock.ConnInfo{EndpointNodeIDField: 0}, }, }, { name: "NonZero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 10}, + source: []conn.Info{ + &mock.ConnInfo{EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointNodeIDField: 10}, }, - res: map[uint32]conn.Conn{ - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, + res: map[uint32]conn.Info{ + 1: &mock.ConnInfo{EndpointNodeIDField: 1}, + 10: &mock.ConnInfo{EndpointNodeIDField: 10}, }, }, { name: "Combined", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 0}, - &mock.Conn{NodeIDField: 10}, + source: []conn.Info{ + &mock.ConnInfo{EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointNodeIDField: 0}, + &mock.ConnInfo{EndpointNodeIDField: 10}, }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, + res: map[uint32]conn.Info{ + 0: &mock.ConnInfo{EndpointNodeIDField: 0}, + 1: &mock.ConnInfo{EndpointNodeIDField: 1}, + 10: &mock.ConnInfo{EndpointNodeIDField: 10}, }, }, } @@ -65,9 +66,9 @@ func TestConnsToNodeIDMap(t *testing.T) { } } -type filterFunc func(info balancerConfig.Info, c conn.Conn) bool +type filterFunc func(info balancerConfig.Info, c conn.Info) bool -func (f 
filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool { +func (f filterFunc) Allow(info balancerConfig.Info, c conn.Info) bool { return f(info, c) } @@ -78,11 +79,11 @@ func (f filterFunc) String() string { func TestSortPreferConnections(t *testing.T) { table := []struct { name string - source []conn.Conn + source []conn.Info allowFallback bool filter balancerConfig.Filter - prefer []conn.Conn - fallback []conn.Conn + prefer []conn.Info + fallback []conn.Info }{ { name: "Empty", @@ -94,55 +95,55 @@ func TestSortPreferConnections(t *testing.T) { }, { name: "NilFilter", - source: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, + source: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1"}, + &mock.ConnInfo{EndpointAddrField: "2"}, }, allowFallback: false, filter: nil, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1"}, + &mock.ConnInfo{EndpointAddrField: "2"}, }, fallback: nil, }, { name: "FilterNoFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: "f2"}, + source: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1"}, + &mock.ConnInfo{EndpointAddrField: "f1"}, + &mock.ConnInfo{EndpointAddrField: "t2"}, + &mock.ConnInfo{EndpointAddrField: "f2"}, }, allowFallback: false, - filter: filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { + filter: filterFunc(func(_ balancerConfig.Info, c conn.Info) bool { return strings.HasPrefix(c.Endpoint().Address(), "t") }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1"}, + &mock.ConnInfo{EndpointAddrField: "t2"}, }, fallback: nil, }, { name: "FilterWithFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: 
"f2"}, + source: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1"}, + &mock.ConnInfo{EndpointAddrField: "f1"}, + &mock.ConnInfo{EndpointAddrField: "t2"}, + &mock.ConnInfo{EndpointAddrField: "f2"}, }, allowFallback: true, - filter: filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { + filter: filterFunc(func(_ balancerConfig.Info, c conn.Info) bool { return strings.HasPrefix(c.Endpoint().Address(), "t") }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1"}, + &mock.ConnInfo{EndpointAddrField: "t2"}, }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "f2"}, + fallback: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "f1"}, + &mock.ConnInfo{EndpointAddrField: "f2"}, }, }, } @@ -157,39 +158,52 @@ func TestSortPreferConnections(t *testing.T) { } func TestSelectRandomConnection(t *testing.T) { - s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) + r := xrand.New(xrand.WithLock()) t.Run("Empty", func(t *testing.T) { - c, failedCount := s.selectRandomConnection(nil, false) + c, failedCount, has := selectRandomConnection[conn.Info](r, nil, false) + require.False(t, has) require.Nil(t, c) require.Equal(t, 0, failedCount) }) t.Run("One", func(t *testing.T) { for _, goodState := range []conn.State{conn.Online, conn.Offline, conn.Created} { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: goodState}}, false) - require.Equal(t, &mock.Conn{AddrField: "asd", State: goodState}, c) + c, failedCount, has := selectRandomConnection(r, + []conn.Info{&mock.ConnInfo{EndpointAddrField: "asd", ConnState: goodState}}, false, + ) + require.True(t, has) + require.NotNil(t, c) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "asd", ConnState: goodState}, c) require.Equal(t, 0, failedCount) } }) t.Run("OneBanned", func(t *testing.T) { - c, failedCount := 
s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, false) + c, failedCount, has := selectRandomConnection(r, + []conn.Info{&mock.ConnInfo{EndpointAddrField: "asd", ConnState: conn.Banned}}, false, + ) + require.False(t, has) require.Nil(t, c) require.Equal(t, 1, failedCount) - c, failedCount = s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, true) - require.Equal(t, &mock.Conn{AddrField: "asd", State: conn.Banned}, c) + c, failedCount, has = selectRandomConnection(r, + []conn.Info{&mock.ConnInfo{EndpointAddrField: "asd", ConnState: conn.Banned}}, true, + ) + require.True(t, has) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "asd", ConnState: conn.Banned}, c) require.Equal(t, 0, failedCount) }) t.Run("Two", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online}, } first := 0 second := 0 for i := 0; i < 100; i++ { - c, _ := s.selectRandomConnection(conns, false) + c, _, has := selectRandomConnection(r, conns, false) + require.True(t, has) + require.NotNil(t, c) if c.Endpoint().Address() == "1" { first++ } else { @@ -201,30 +215,33 @@ func TestSelectRandomConnection(t *testing.T) { require.InDelta(t, 50, second, 21) }) t.Run("TwoBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Banned}, - &mock.Conn{AddrField: "2", State: conn.Banned}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Banned}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Banned}, } totalFailed := 0 for i := 0; i < 100; i++ { - c, failed := s.selectRandomConnection(conns, false) + c, failed, has := selectRandomConnection(r, conns, false) + require.False(t, has) require.Nil(t, c) totalFailed += 
failed } require.Equal(t, 200, totalFailed) }) t.Run("ThreeWithBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - &mock.Conn{AddrField: "3", State: conn.Banned}, + conns := []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Banned}, } first := 0 second := 0 failed := 0 for i := 0; i < 100; i++ { - c, checkFailed := s.selectRandomConnection(conns, false) + c, checkFailed, has := selectRandomConnection(r, conns, false) failed += checkFailed + require.True(t, has) + require.NotNil(t, c) switch c.Endpoint().Address() { case "1": first++ @@ -244,13 +261,13 @@ func TestSelectRandomConnection(t *testing.T) { func TestNewState(t *testing.T) { table := []struct { name string - state *connectionsState - res *connectionsState + state *connectionsState[conn.Info] + res *connectionsState[conn.Info] }{ { name: "Empty", - state: newConnectionsState(nil, nil, balancerConfig.Info{}, false), - res: &connectionsState{ + state: newConnectionsState[conn.Info](nil, nil, balancerConfig.Info{}, false), + res: &connectionsState[conn.Info]{ connByNodeID: nil, prefer: nil, fallback: nil, @@ -259,117 +276,117 @@ func TestNewState(t *testing.T) { }, { name: "NoFilter", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, + state: newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointAddrField: "2", EndpointNodeIDField: 2}, }, nil, balancerConfig.Info{}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "1", NodeIDField: 1}, - 2: &mock.Conn{AddrField: "2", NodeIDField: 2}, + res: &connectionsState[conn.Info]{ + connByNodeID: map[uint32]conn.Info{ 
+ 1: &mock.ConnInfo{EndpointAddrField: "1", EndpointNodeIDField: 1}, + 2: &mock.ConnInfo{EndpointAddrField: "2", EndpointNodeIDField: 2}, }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointAddrField: "2", EndpointNodeIDField: 2}, }, fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, + all: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointAddrField: "2", EndpointNodeIDField: 2}, }, }, }, { name: "FilterDenyFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + state: newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: 
&mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + res: &connectionsState[conn.Info]{ + connByNodeID: map[uint32]conn.Info{ + 1: &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + 2: &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + 3: &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + 4: &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, }, fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + all: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, }, }, }, { name: "FilterAllowFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + state: newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 
3, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + res: &connectionsState[conn.Info]{ + connByNodeID: map[uint32]conn.Info{ + 1: &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + 2: &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + 3: &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + 4: &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + fallback: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, 
LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + all: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, }, }, { name: "WithNodeID", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + state: newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return info.SelfLocation == c.Endpoint().Location() }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + res: &connectionsState[conn.Info]{ + connByNodeID: map[uint32]conn.Info{ + 1: 
&mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + 2: &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + 3: &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + 4: &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, + prefer: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + fallback: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, + all: []conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", EndpointNodeIDField: 1, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f1", EndpointNodeIDField: 2, EndpointLocationField: "f"}, + &mock.ConnInfo{EndpointAddrField: "t2", EndpointNodeIDField: 3, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", EndpointNodeIDField: 4, EndpointLocationField: "f"}, }, }, }, @@ -386,33 +403,33 @@ func TestNewState(t *testing.T) { func TestConnection(t *testing.T) { t.Run("Empty", func(t *testing.T) { - s := newConnectionsState(nil, 
nil, balancerConfig.Info{}, false) + s := newConnectionsState[conn.Info](nil, nil, balancerConfig.Info{}, false) c, failed := s.GetConnection(context.Background()) require.Nil(t, c) require.Equal(t, 0, failed) }) t.Run("AllGood", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online}, }, nil, balancerConfig.Info{}, false) c, failed := s.GetConnection(context.Background()) require.NotNil(t, c) require.Equal(t, 0, failed) }) t.Run("WithBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Banned}, + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Banned}, }, nil, balancerConfig.Info{}, false) c, _ := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online}, c) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online}, c) }) t.Run("AllBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Banned, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", ConnState: conn.Banned, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", ConnState: conn.Banned, EndpointLocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return c.Endpoint().Location() == info.SelfLocation }), balancerConfig.Info{}, true) preferred := 0 @@ -432,32 
+449,32 @@ func TestConnection(t *testing.T) { require.InDelta(t, 50, fallback, 21) }) t.Run("PreferBannedWithFallback", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "t1", ConnState: conn.Banned, EndpointLocationField: "t"}, + &mock.ConnInfo{EndpointAddrField: "f2", ConnState: conn.Online, EndpointLocationField: "f"}, + }, filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return c.Endpoint().Location() == info.SelfLocation }), balancerConfig.Info{SelfLocation: "t"}, true) c, failed := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "f2", ConnState: conn.Online, EndpointLocationField: "f"}, c) require.Equal(t, 1, failed) }) t.Run("PreferNodeID", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online, EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointNodeIDField: 2}, }, nil, balancerConfig.Info{}, false) c, failed := s.GetConnection(WithEndpoint(context.Background(), &mock.Endpoint{AddrField: "2", NodeIDField: 2})) - require.Equal(t, &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, c) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointNodeIDField: 2}, c) require.Equal(t, 0, failed) }) t.Run("PreferNodeIDWithBadState", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - 
&mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Unknown, NodeIDField: 2}, + s := newConnectionsState([]conn.Info{ + &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online, EndpointNodeIDField: 1}, + &mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Unknown, EndpointNodeIDField: 2}, }, nil, balancerConfig.Info{}, false) c, failed := s.GetConnection(WithEndpoint(context.Background(), &mock.Endpoint{AddrField: "2", NodeIDField: 2})) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, c) + require.Equal(t, &mock.ConnInfo{EndpointAddrField: "1", ConnState: conn.Online, EndpointNodeIDField: 1}, c) require.Equal(t, 0, failed) }) } diff --git a/internal/balancer/errors.go b/internal/balancer/errors.go new file mode 100644 index 000000000..190507489 --- /dev/null +++ b/internal/balancer/errors.go @@ -0,0 +1,9 @@ +package balancer + +import ( + "errors" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" +) + +var errBalancerClosed = xerrors.Wrap(errors.New("balancer closed")) diff --git a/internal/conn/cancels_quard.go b/internal/conn/cancels_quard.go new file mode 100644 index 000000000..c932c6d0e --- /dev/null +++ b/internal/conn/cancels_quard.go @@ -0,0 +1,40 @@ +package conn + +import ( + "context" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" +) + +type cancelsGuard struct { + mu xsync.Mutex + cancels map[*context.CancelFunc]struct{} +} + +func newCancelsGuard() *cancelsGuard { + return &cancelsGuard{ + mu: xsync.Mutex{}, + cancels: make(map[*context.CancelFunc]struct{}), + } +} + +func (g *cancelsGuard) Remember(cancel *context.CancelFunc) { + g.mu.Lock() + defer g.mu.Unlock() + g.cancels[cancel] = struct{}{} +} + +func (g *cancelsGuard) Forget(cancel *context.CancelFunc) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.cancels, cancel) +} + +func (g *cancelsGuard) Cancel() { + g.mu.Lock() + defer g.mu.Unlock() + for cancel := range g.cancels { + 
(*cancel)() + } + g.cancels = make(map[*context.CancelFunc]struct{}) +} diff --git a/internal/conn/cancels_quard_test.go b/internal/conn/cancels_quard_test.go new file mode 100644 index 000000000..df39125ee --- /dev/null +++ b/internal/conn/cancels_quard_test.go @@ -0,0 +1,24 @@ +package conn + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/net/context" +) + +func TestCancelsGuard(t *testing.T) { + g := newCancelsGuard() + ctx, cancel1 := context.WithCancel(context.Background()) + g.Remember(&cancel1) + require.Len(t, g.cancels, 1) + g.Forget(&cancel1) + require.Empty(t, g.cancels, 0) + cancel2 := context.CancelFunc(func() { + cancel1() + }) + g.Remember(&cancel2) + require.Len(t, g.cancels, 1) + g.Cancel() + require.Error(t, ctx.Err()) +} diff --git a/internal/conn/cc_guard.go b/internal/conn/cc_guard.go new file mode 100644 index 000000000..21d251939 --- /dev/null +++ b/internal/conn/cc_guard.go @@ -0,0 +1,116 @@ +package conn + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" +) + +type ccGuard struct { + dial func(ctx context.Context) (cc *grpc.ClientConn, err error) + cc *grpc.ClientConn + s State + mu sync.RWMutex +} + +func newCcGuard(dial func(ctx context.Context) (cc *grpc.ClientConn, err error)) *ccGuard { + return &ccGuard{ + dial: dial, + s: Created, + } +} + +func (g *ccGuard) Get(ctx context.Context) (*grpc.ClientConn, error) { + g.mu.Lock() + defer g.mu.Unlock() + + if g.cc != nil { + return g.cc, nil + } + + cc, err := g.dial(ctx) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + g.cc = cc + g.s = Online + + return cc, nil +} + +func (g *ccGuard) Ready() bool { + g.mu.Lock() + defer g.mu.Unlock() + + return g.ready() +} + +func (g *ccGuard) ready() bool { + return g.cc != nil && g.cc.GetState() == connectivity.Ready +} + +func (g *ccGuard) Unban() State { + g.mu.Lock() + defer 
g.mu.Unlock() + + if g.ready() { + g.s = Online + } else { + g.s = Offline + } + + return g.s +} + +func (g *ccGuard) Ban() State { + g.mu.Lock() + defer g.mu.Unlock() + + g.s = Banned + + return g.s +} + +func (g *ccGuard) State() State { + g.mu.RLock() + defer g.mu.RUnlock() + + return g.s +} + +func (g *ccGuard) Close(ctx context.Context) error { + g.mu.Lock() + defer g.mu.Unlock() + + if g.cc == nil { + return nil + } + + err := g.cc.Close() + g.cc = nil + g.s = Closed + + if err != nil { + return xerrors.WithStackTrace(err) + } + + return nil +} + +func (g *ccGuard) Park() error { + g.mu.Lock() + defer g.mu.Unlock() + + if g.cc == nil { + return xerrors.WithStackTrace(errClosedConnection) + } + + g.s = Parked + + return nil +} diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 66278bb94..0962c531b 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -3,13 +3,10 @@ package conn import ( "context" "fmt" - "sync" "sync/atomic" - "time" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" @@ -28,73 +25,51 @@ var ( // errClosedConnection specified error when connection are closed early errClosedConnection = xerrors.Wrap(fmt.Errorf("connection closed early")) - - // errUnavailableConnection specified error when connection are closed early - errUnavailableConnection = xerrors.Wrap(fmt.Errorf("connection unavailable")) ) -type Conn interface { - grpc.ClientConnInterface +type ( + Info interface { + Endpoint() endpoint.Endpoint + State() State + } + Conn interface { + grpc.ClientConnInterface + Info - Endpoint() endpoint.Endpoint + Ban() State + Unban() State + } +) - LastUsage() time.Time +type lazyConn struct { + config Config // ro access + endpoint endpoint.Endpoint // ro access - Ping(ctx context.Context) error - IsState(states ...State) bool - GetState() State - SetState(ctx context.Context, state State) 
State - Unban(ctx context.Context) State -} + cc *ccGuard -type conn struct { - mtx sync.RWMutex - config Config // ro access - cc *grpc.ClientConn - done chan struct{} - endpoint endpoint.Endpoint // ro access - closed bool - state atomic.Uint32 - lastUsage *lastUsage - onClose []func(*conn) - onTransportErrors []func(ctx context.Context, cc Conn, cause error) -} + inUse inUseGuard -func (c *conn) Address() string { - return c.endpoint.Address() -} + childStreams *cancelsGuard -func (c *conn) Ping(ctx context.Context) error { - cc, err := c.realConn(ctx) - if err != nil { - return c.wrapError(err) - } - if !isAvailable(cc) { - return c.wrapError(errUnavailableConnection) - } + lastUsage *lastUsageGuard - return nil + onClose []func(*lazyConn) + onTransportErrors []func(ctx context.Context, cc Conn, cause error) } -func (c *conn) LastUsage() time.Time { - c.mtx.RLock() - defer c.mtx.RUnlock() - - return c.lastUsage.Get() +func (c *lazyConn) Ban() State { + return c.cc.Ban() } -func (c *conn) IsState(states ...State) bool { - state := State(c.state.Load()) - for _, s := range states { - if s == state { - return true - } - } +func (c *lazyConn) Unban() State { + return c.cc.Unban() +} - return false +func (c *lazyConn) Address() string { + return c.endpoint.Address() } -func (c *conn) NodeID() uint32 { +func (c *lazyConn) NodeID() uint32 { if c != nil { return c.endpoint.NodeID() } @@ -102,37 +77,31 @@ func (c *conn) NodeID() uint32 { return 0 } -func (c *conn) park(ctx context.Context) (err error) { +func (c *lazyConn) park(ctx context.Context) (finalErr error) { onDone := trace.DriverOnConnPark( c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).park"), + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*lazyConn).park"), c.Endpoint(), ) defer func() { - onDone(err) + onDone(finalErr) }() - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.closed { + locked, unlock := c.inUse.TryLock() + if !locked 
{ return nil } + defer unlock() - if c.cc == nil { - return nil - } - - err = c.close(ctx) - + err := c.cc.Park() if err != nil { - return c.wrapError(err) + return xerrors.WithStackTrace(err) } return nil } -func (c *conn) Endpoint() endpoint.Endpoint { +func (c *lazyConn) Endpoint() endpoint.Endpoint { if c != nil { return c.endpoint } @@ -140,164 +109,42 @@ func (c *conn) Endpoint() endpoint.Endpoint { return nil } -func (c *conn) SetState(ctx context.Context, s State) State { - return c.setState(ctx, s) -} - -func (c *conn) setState(ctx context.Context, s State) State { - if state := State(c.state.Swap(uint32(s))); state != s { - trace.DriverOnConnStateChange( - c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).setState"), - c.endpoint.Copy(), state, - )(s) - } - - return s -} - -func (c *conn) Unban(ctx context.Context) State { - var newState State - c.mtx.RLock() - cc := c.cc - c.mtx.RUnlock() - if isAvailable(cc) { - newState = Online - } else { - newState = Offline - } - - c.setState(ctx, newState) - - return newState -} - -func (c *conn) GetState() (s State) { - return State(c.state.Load()) -} - -func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { - if c.isClosed() { - return nil, c.wrapError(errClosedConnection) - } - - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.cc != nil { - return c.cc, nil - } - - if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) - defer cancel() - } - - onDone := trace.DriverOnConnDial( - c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).realConn"), - c.endpoint.Copy(), - ) - defer func() { - onDone(err) - }() - - // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme - // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. 
- address := "ydb:///" + c.endpoint.Address() - - cc, err = grpc.DialContext(ctx, address, append( - []grpc.DialOption{ - grpc.WithStatsHandler(statsHandler{}), - }, c.config.GrpcDialOptions()..., - )...) - if err != nil { - if xerrors.IsContextError(err) { - return nil, xerrors.WithStackTrace(err) - } - - defer func() { - c.onTransportError(ctx, err) - }() - - err = xerrors.Transport(err, - xerrors.WithAddress(address), - ) - - return nil, c.wrapError( - xerrors.Retryable(err, - xerrors.WithName("realConn"), - ), - ) - } - - c.cc = cc - c.setState(ctx, Online) - - return c.cc, nil -} - -func (c *conn) onTransportError(ctx context.Context, cause error) { +func (c *lazyConn) onTransportError(ctx context.Context, cause error) { for _, onTransportError := range c.onTransportErrors { onTransportError(ctx, c, cause) } } -func isAvailable(raw *grpc.ClientConn) bool { - return raw != nil && raw.GetState() == connectivity.Ready -} - -// conn must be locked -func (c *conn) close(ctx context.Context) (err error) { - if c.cc == nil { - return nil - } - err = c.cc.Close() - c.cc = nil - c.setState(ctx, Offline) - - return c.wrapError(err) -} - -func (c *conn) isClosed() bool { - c.mtx.RLock() - defer c.mtx.RUnlock() - - return c.closed +func (c *lazyConn) State() State { + return c.cc.State() } -func (c *conn) Close(ctx context.Context) (err error) { - c.mtx.Lock() - defer c.mtx.Unlock() - - if c.closed { - return nil - } - +func (c *lazyConn) Close(ctx context.Context) (finalErr error) { onDone := trace.DriverOnConnClose( c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Close"), + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*lazyConn).Close"), c.Endpoint(), ) defer func() { - onDone(err) + onDone(finalErr) }() - c.closed = true - - err = c.close(ctx) - - c.setState(ctx, Destroyed) + defer func() { + for _, onClose := range c.onClose { + onClose(c) + } + c.inUse.Stop() + }() - for _, onClose := 
range c.onClose { - onClose(c) + err := c.cc.Close(ctx) + if err != nil { + return xerrors.WithStackTrace(err) } - return c.wrapError(err) + return nil } -func (c *conn) Invoke( +func (c *lazyConn) Invoke( ctx context.Context, method string, req interface{}, @@ -310,7 +157,7 @@ func (c *conn) Invoke( useWrapping = UseWrapping(ctx) onDone = trace.DriverOnConnInvoke( c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Invoke"), + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*lazyConn).Invoke"), c.endpoint, trace.Method(method), ) cc *grpc.ClientConn @@ -318,10 +165,16 @@ func (c *conn) Invoke( ) defer func() { meta.CallTrailerCallback(ctx, md) - onDone(err, issues, opID, c.GetState(), md) + onDone(err, issues, opID, c.State(), md) }() - cc, err = c.realConn(ctx) + locked, unlock := c.inUse.TryLock() + if !locked { + return xerrors.WithStackTrace(errClosedConnection) + } + defer unlock() + + cc, err = c.cc.Get(ctx) if err != nil { return c.wrapError(err) } @@ -337,28 +190,26 @@ func (c *conn) Invoke( ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) err = cc.Invoke(ctx, method, req, res, append(opts, grpc.Trailer(&md))...) 
- if err != nil { - if xerrors.IsContextError(err) { - return xerrors.WithStackTrace(err) - } - - defer func() { - c.onTransportError(ctx, err) - }() + if err != nil { //nolint:nestif + if xerrors.IsTransportError(err) { + defer func() { + c.onTransportError(ctx, err) + }() + + if useWrapping { + err = xerrors.Transport(err, + xerrors.WithAddress(c.Address()), + xerrors.WithTraceID(traceID), + ) + if sentMark.canRetry() { + return c.wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke"))) + } - if useWrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(c.Address()), - xerrors.WithTraceID(traceID), - ) - if sentMark.canRetry() { - return c.wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke"))) + return c.wrapError(err) } - - return c.wrapError(err) } - return err + return xerrors.WithStackTrace(err) } if o, ok := res.(response.Response); ok { @@ -387,28 +238,32 @@ func (c *conn) Invoke( } //nolint:funlen -func (c *conn) NewStream( +func (c *lazyConn) NewStream( ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption, -) (_ grpc.ClientStream, err error) { +) (_ grpc.ClientStream, finalErr error) { var ( onDone = trace.DriverOnConnNewStream( c.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).NewStream"), + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*lazyConn).NewStream"), c.endpoint.Copy(), trace.Method(method), ) useWrapping = UseWrapping(ctx) - cc *grpc.ClientConn - s grpc.ClientStream ) defer func() { - onDone(err, c.GetState()) + onDone(finalErr, c.State()) }() - cc, err = c.realConn(ctx) + locked, unlock := c.inUse.TryLock() + if !locked { + return nil, xerrors.WithStackTrace(errClosedConnection) + } + defer unlock() + + cc, err := c.cc.Get(ctx) if err != nil { return nil, c.wrapError(err) } @@ -423,29 +278,38 @@ func (c *conn) NewStream( ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) - s, err = cc.NewStream(ctx, desc, 
method, opts...) - if err != nil { - if xerrors.IsContextError(err) { - return nil, xerrors.WithStackTrace(err) + ctx, cancel := xcontext.WithCancel(ctx) + defer func() { + if finalErr != nil { + cancel() + } else { + c.childStreams.Remember(&cancel) } + }() - defer func() { - c.onTransportError(ctx, err) - }() + s, err := cc.NewStream(ctx, desc, method, append(opts, grpc.OnFinish(func(err error) { + cancel() + c.childStreams.Forget(&cancel) + }))...) + if err != nil { //nolint:nestif + if xerrors.IsTransportError(err) { + defer func() { + c.onTransportError(ctx, err) + }() + if useWrapping { + err = xerrors.Transport(err, + xerrors.WithAddress(c.Address()), + xerrors.WithTraceID(traceID), + ) + if sentMark.canRetry() { + return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream"))) + } - if useWrapping { - err = xerrors.Transport(err, - xerrors.WithAddress(c.Address()), - xerrors.WithTraceID(traceID), - ) - if sentMark.canRetry() { - return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream"))) + return s, c.wrapError(err) } - - return s, c.wrapError(err) } - return s, err + return nil, xerrors.WithStackTrace(err) } return &grpcClientStream{ @@ -461,7 +325,7 @@ func (c *conn) NewStream( }, nil } -func (c *conn) wrapError(err error) error { +func (c *lazyConn) wrapError(err error) error { if err == nil { return nil } @@ -470,10 +334,10 @@ func (c *conn) wrapError(err error) error { return xerrors.WithStackTrace(nodeErr, xerrors.WithSkipDepth(1)) } -type option func(c *conn) +type option func(c *lazyConn) -func withOnClose(onClose func(*conn)) option { - return func(c *conn) { +func withOnClose(onClose func(*lazyConn)) option { + return func(c *lazyConn) { if onClose != nil { c.onClose = append(c.onClose, onClose) } @@ -481,32 +345,94 @@ func withOnClose(onClose func(*conn)) option { } func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { - return func(c *conn) { + return func(c *lazyConn) 
{ if onTransportError != nil { c.onTransportErrors = append(c.onTransportErrors, onTransportError) } } } -func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { - c := &conn{ - endpoint: e, - config: config, - done: make(chan struct{}), - lastUsage: newLastUsage(nil), +func dial( + ctx context.Context, + t *trace.Driver, + e endpoint.Info, + opts ...grpc.DialOption, +) (_ *grpc.ClientConn, finalErr error) { + onDone := trace.DriverOnConnDial( + t, &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.dial"), + e, + ) + defer func() { + onDone(finalErr) + }() + + // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme + // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. + address := "ydb:///" + e.Address() + + cc, err := grpc.DialContext(ctx, address, append( + []grpc.DialOption{ + grpc.WithStatsHandler(statsHandler{}), + }, opts..., + )...) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return cc, nil +} + +func newConn(e endpoint.Endpoint, config Config, opts ...option) *lazyConn { + c := &lazyConn{ + endpoint: e, + config: config, + lastUsage: newLastUsage(nil), + childStreams: newCancelsGuard(), + onClose: []func(*lazyConn){ + func(c *lazyConn) { + c.childStreams.Cancel() + c.inUse.Stop() + }, + }, } - c.state.Store(uint32(Created)) for _, opt := range opts { if opt != nil { opt(c) } } + c.cc = newCcGuard(func(ctx context.Context) (*grpc.ClientConn, error) { + if dialTimeout := c.config.DialTimeout(); dialTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = xcontext.WithTimeout(ctx, dialTimeout) + defer cancel() + } - return c -} + cc, err := dial(ctx, c.config.Trace(), c.endpoint.Copy(), c.config.GrpcDialOptions()...) 
+ if err != nil { + if xerrors.IsTransportError(err) { + defer func() { + c.onTransportError(ctx, err) + }() + + return nil, xerrors.WithStackTrace( + xerrors.Retryable( + xerrors.Transport(err, + xerrors.WithAddress(c.endpoint.Address()), + ), + ), + ) + } + + return nil, xerrors.WithStackTrace( + xerrors.Retryable(err, xerrors.WithName("dial")), + ) + } -func New(e endpoint.Endpoint, config Config, opts ...option) Conn { - return newConn(e, config, opts...) + return cc, nil + }) + + return c } var _ stats.Handler = statsHandler{} diff --git a/internal/conn/grpc_client_stream.go b/internal/conn/grpc_client_stream.go index 32377e5ab..99326c51e 100644 --- a/internal/conn/grpc_client_stream.go +++ b/internal/conn/grpc_client_stream.go @@ -17,7 +17,7 @@ import ( type grpcClientStream struct { grpc.ClientStream ctx context.Context - c *conn + c *lazyConn wrapping bool traceID string sentMark *modificationMark @@ -32,6 +32,12 @@ func (s *grpcClientStream) CloseSend() (err error) { onDone(err) }() + locked, unlock := s.c.inUse.TryLock() + if !locked { + return xerrors.WithStackTrace(errClosedConnection) + } + defer unlock() + stop := s.c.lastUsage.Start() defer stop() @@ -66,6 +72,12 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) { onDone(err) }() + locked, unlock := s.c.inUse.TryLock() + if !locked { + return xerrors.WithStackTrace(errClosedConnection) + } + defer unlock() + stop := s.c.lastUsage.Start() defer stop() @@ -108,6 +120,12 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { onDone(err) }() + locked, unlock := s.c.inUse.TryLock() + if !locked { + return xerrors.WithStackTrace(errClosedConnection) + } + defer unlock() + stop := s.c.lastUsage.Start() defer stop() diff --git a/internal/conn/in_use_quard.go b/internal/conn/in_use_quard.go new file mode 100644 index 000000000..d1bd4b1b8 --- /dev/null +++ b/internal/conn/in_use_quard.go @@ -0,0 +1,35 @@ +package conn + +import "sync" + +type inUseGuard struct { + usages 
sync.WaitGroup + mu sync.Mutex + stopped bool +} + +func (g *inUseGuard) TryLock() (locked bool, unlock func()) { + g.mu.Lock() + defer g.mu.Unlock() + + if g.stopped { + return false, nil + } + + g.usages.Add(1) + + return true, sync.OnceFunc(func() { + g.usages.Done() + }) +} + +func (g *inUseGuard) Stop() { + g.mu.Lock() + defer g.mu.Unlock() + + g.stopped = true + + g.usages.Wait() + + g.stopped = true +} diff --git a/internal/conn/in_use_quard_test.go b/internal/conn/in_use_quard_test.go new file mode 100644 index 000000000..e49e53f3e --- /dev/null +++ b/internal/conn/in_use_quard_test.go @@ -0,0 +1,52 @@ +package conn + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" +) + +func TestInUseGuard(t *testing.T) { + xtest.TestManyTimes(t, func(t testing.TB) { + g := &inUseGuard{} + ch := make(chan struct{}) + unlockFuncs := make([]func(), 10) + for i := range unlockFuncs { + locked, unlock := g.TryLock() + require.True(t, locked) + require.NotNil(t, unlock) + unlockFuncs[i] = func() { + <-ch + unlock() + } + } + waitStop := make(chan struct{}) + go func() { + defer func() { + close(waitStop) + }() + g.Stop() + }() + for i := range unlockFuncs { + go func(i int) { + unlockFuncs[i]() + }(i) + } + for range unlockFuncs { + select { + case <-waitStop: + require.Fail(t, "unexpected stop signal") + case ch <- struct{}{}: + } + } + close(ch) + select { + case <-waitStop: + case <-time.After(time.Second): + require.Fail(t, "not stopped after 1 second") + } + }) +} diff --git a/internal/conn/last_usage.go b/internal/conn/last_usage.go deleted file mode 100644 index b0ca293a9..000000000 --- a/internal/conn/last_usage.go +++ /dev/null @@ -1,47 +0,0 @@ -package conn - -import ( - "sync" - "sync/atomic" - "time" - - "github.com/jonboulle/clockwork" -) - -type lastUsage struct { - locks atomic.Int64 - t atomic.Pointer[time.Time] - clock clockwork.Clock -} - -func newLastUsage(clock 
clockwork.Clock) *lastUsage { - if clock == nil { - clock = clockwork.NewRealClock() - } - now := clock.Now() - usage := &lastUsage{ - clock: clock, - } - usage.t.Store(&now) - - return usage -} - -func (l *lastUsage) Get() time.Time { - if l.locks.Load() == 0 { - return *l.t.Load() - } - - return l.clock.Now() -} - -func (l *lastUsage) Start() (stop func()) { - l.locks.Add(1) - - return sync.OnceFunc(func() { - if l.locks.Add(-1) == 0 { - now := l.clock.Now() - l.t.Store(&now) - } - }) -} diff --git a/internal/conn/last_usage_guard.go b/internal/conn/last_usage_guard.go new file mode 100644 index 000000000..0e169eec0 --- /dev/null +++ b/internal/conn/last_usage_guard.go @@ -0,0 +1,50 @@ +package conn + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/jonboulle/clockwork" +) + +type lastUsageGuard struct { + locks atomic.Int64 + t atomic.Pointer[time.Time] + clock clockwork.Clock +} + +func newLastUsage(clock clockwork.Clock) *lastUsageGuard { + if clock == nil { + clock = clockwork.NewRealClock() + } + + now := clock.Now() + + usage := &lastUsageGuard{ + clock: clock, + } + + usage.t.Store(&now) + + return usage +} + +func (g *lastUsageGuard) Get() time.Time { + if g.locks.Load() == 0 { + return *g.t.Load() + } + + return g.clock.Now() +} + +func (g *lastUsageGuard) Start() (stop func()) { + g.locks.Add(1) + + return sync.OnceFunc(func() { + if g.locks.Add(-1) == 0 { + now := g.clock.Now() + g.t.Store(&now) + } + }) +} diff --git a/internal/conn/last_usage_test.go b/internal/conn/last_usage_guard_test.go similarity index 96% rename from internal/conn/last_usage_test.go rename to internal/conn/last_usage_guard_test.go index b7c79695e..b163e9ef2 100644 --- a/internal/conn/last_usage_test.go +++ b/internal/conn/last_usage_guard_test.go @@ -12,7 +12,7 @@ func Test_lastUsage_Lock(t *testing.T) { t.Run("NowFromLocked", func(t *testing.T) { start := time.Unix(0, 0) clock := clockwork.NewFakeClockAt(start) - lu := &lastUsage{ + lu := &lastUsageGuard{ clock: 
clock, } lu.t.Store(&start) @@ -33,7 +33,7 @@ func Test_lastUsage_Lock(t *testing.T) { t.Run("UpdateAfterLastUnlock", func(t *testing.T) { start := time.Unix(0, 0) clock := clockwork.NewFakeClockAt(start) - lu := &lastUsage{ + lu := &lastUsageGuard{ clock: clock, } lu.t.Store(&start) @@ -71,7 +71,7 @@ func Test_lastUsage_Lock(t *testing.T) { t.Run("DeferRelease", func(t *testing.T) { start := time.Unix(0, 0) clock := clockwork.NewFakeClockAt(start) - lu := &lastUsage{ + lu := &lastUsageGuard{ clock: clock, } lu.t.Store(&start) diff --git a/internal/conn/pool.go b/internal/conn/pool.go index 6115c612f..0bb113ca0 100644 --- a/internal/conn/pool.go +++ b/internal/conn/pool.go @@ -2,14 +2,13 @@ package conn import ( "context" - "sync" "sync/atomic" "time" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" grpcCodes "google.golang.org/grpc/codes" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" @@ -28,7 +27,7 @@ type Pool struct { config Config mtx xsync.RWMutex opts []grpc.DialOption - conns map[connsKey]*conn + conns map[connsKey]*lazyConn done chan struct{} } @@ -38,7 +37,7 @@ func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { var ( address = endpoint.Address() - cc *conn + cc *lazyConn has bool ) @@ -60,7 +59,7 @@ func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { return cc } -func (p *Pool) remove(c *conn) { +func (p *Pool) remove(c *lazyConn) { p.mtx.Lock() defer p.mtx.Unlock() delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()}) @@ -115,11 +114,11 @@ func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { trace.DriverOnConnBan( p.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Ban"), - e, cc.GetState(), cause, - )(cc.SetState(ctx, Banned)) + e, cc.State(), cause, + )(cc.Ban()) } -func (p *Pool) 
Allow(ctx context.Context, cc Conn) { +func (p *Pool) Unban(ctx context.Context, cc Conn) { if p.isClosed() { return } @@ -134,22 +133,29 @@ func (p *Pool) Allow(ctx context.Context, cc Conn) { return } - trace.DriverOnConnAllow( + trace.DriverOnConnUnban( p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Allow"), - e, cc.GetState(), - )(cc.Unban(ctx)) + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Unban"), + e, cc.State(), + )(cc.Unban()) } -func (p *Pool) Take(context.Context) error { +func (p *Pool) Attach(ctx context.Context) (finalErr error) { + onDone := trace.DriverOnPoolAttach(p.config.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Attach"), + ) + defer func() { + onDone(finalErr) + }() + atomic.AddInt64(&p.usages, 1) return nil } -func (p *Pool) Release(ctx context.Context) (finalErr error) { +func (p *Pool) Detach(ctx context.Context) (finalErr error) { onDone := trace.DriverOnPoolRelease(p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Release"), + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Detach"), ) defer func() { onDone(finalErr) @@ -161,38 +167,18 @@ func (p *Pool) Release(ctx context.Context) (finalErr error) { close(p.done) - var conns []closer.Closer + var g errgroup.Group p.mtx.WithRLock(func() { - conns = make([]closer.Closer, 0, len(p.conns)) - for _, c := range p.conns { - conns = append(conns, c) + for key := range p.conns { + conn := p.conns[key] + g.Go(func() error { + return conn.Close(ctx) + }) } }) - var ( - errCh = make(chan error, len(conns)) - wg sync.WaitGroup - ) - - wg.Add(len(conns)) - for _, c := range conns { - go func(c closer.Closer) { - defer wg.Done() - if err := c.Close(ctx); err != nil { - errCh <- err - } - }(c) - } - wg.Wait() - close(errCh) - - issues := make([]error, 0, len(conns)) - for err := range 
errCh { - issues = append(issues, err) - } - - if len(issues) > 0 { - return xerrors.WithStackTrace(xerrors.NewWithIssues("connection pool close failed", issues...)) + if err := g.Wait(); err != nil { + return xerrors.WithStackTrace(err) } return nil @@ -207,8 +193,8 @@ func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { return case <-ticker.C: for _, c := range p.collectConns() { - if time.Since(c.LastUsage()) > ttl { - switch c.GetState() { + if time.Since(c.lastUsage.Get()) > ttl { + switch c.State() { case Online, Banned: _ = c.park(ctx) default: @@ -220,10 +206,10 @@ func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { } } -func (p *Pool) collectConns() []*conn { +func (p *Pool) collectConns() []*lazyConn { p.mtx.RLock() defer p.mtx.RUnlock() - conns := make([]*conn, 0, len(p.conns)) + conns := make([]*lazyConn, 0, len(p.conns)) for _, c := range p.conns { conns = append(conns, c) } @@ -241,7 +227,7 @@ func NewPool(ctx context.Context, config Config) *Pool { usages: 1, config: config, opts: config.GrpcDialOptions(), - conns: make(map[connsKey]*conn), + conns: make(map[connsKey]*lazyConn), done: make(chan struct{}), } diff --git a/internal/conn/state.go b/internal/conn/state.go index e0c67f73e..eaf14dc0b 100644 --- a/internal/conn/state.go +++ b/internal/conn/state.go @@ -6,9 +6,10 @@ const ( Unknown = State(iota) Created Online + Parked Banned Offline - Destroyed + Closed ) func (s State) Code() int { @@ -21,12 +22,14 @@ func (s State) String() string { return "created" case Online: return "online" + case Parked: + return "parked" case Banned: return "banned" case Offline: return "offline" - case Destroyed: - return "destroyed" + case Closed: + return "closed" default: return "unknown" } @@ -34,7 +37,7 @@ func (s State) String() string { func (s State) IsValid() bool { switch s { - case Online, Offline, Banned: + case Online, Offline, Banned, Parked: return true default: return false diff --git 
a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 823d721d0..cc3daa9fd 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -7,6 +7,8 @@ import ( ) type Info interface { + fmt.Stringer + NodeID() uint32 Address() string LocalDC() bool @@ -18,7 +20,6 @@ type Info interface { type Endpoint interface { Info - String() string Copy() Endpoint Touch(opts ...Option) } diff --git a/internal/mock/conn.go b/internal/mock/conn.go deleted file mode 100644 index ac57f6c41..000000000 --- a/internal/mock/conn.go +++ /dev/null @@ -1,125 +0,0 @@ -package mock - -import ( - "context" - "time" - - "google.golang.org/grpc" - - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" -) - -type Conn struct { - PingErr error - AddrField string - LocationField string - NodeIDField uint32 - State conn.State - LocalDCField bool -} - -func (c *Conn) Invoke( - ctx context.Context, - method string, - args interface{}, - reply interface{}, - opts ...grpc.CallOption, -) error { - panic("not implemented in mock") -} - -func (c *Conn) NewStream(ctx context.Context, - desc *grpc.StreamDesc, method string, - opts ...grpc.CallOption, -) (grpc.ClientStream, error) { - panic("not implemented in mock") -} - -func (c *Conn) Endpoint() endpoint.Endpoint { - return &Endpoint{ - AddrField: c.AddrField, - LocalDCField: c.LocalDCField, - LocationField: c.LocationField, - NodeIDField: c.NodeIDField, - } -} - -func (c *Conn) LastUsage() time.Time { - panic("not implemented in mock") -} - -func (c *Conn) Park(ctx context.Context) (err error) { - panic("not implemented in mock") -} - -func (c *Conn) Ping(ctx context.Context) error { - return c.PingErr -} - -func (c *Conn) IsState(states ...conn.State) bool { - panic("not implemented in mock") -} - -func (c *Conn) GetState() conn.State { - return c.State -} - -func (c *Conn) SetState(ctx context.Context, state conn.State) conn.State { - c.State = state - - 
return c.State -} - -func (c *Conn) Unban(ctx context.Context) conn.State { - c.SetState(ctx, conn.Online) - - return conn.Online -} - -type Endpoint struct { - AddrField string - LocationField string - NodeIDField uint32 - LocalDCField bool -} - -func (e *Endpoint) Choose(bool) { -} - -func (e *Endpoint) NodeID() uint32 { - return e.NodeIDField -} - -func (e *Endpoint) Address() string { - return e.AddrField -} - -func (e *Endpoint) LocalDC() bool { - return e.LocalDCField -} - -func (e *Endpoint) Location() string { - return e.LocationField -} - -func (e *Endpoint) LastUpdated() time.Time { - panic("not implemented in mock") -} - -func (e *Endpoint) LoadFactor() float32 { - panic("not implemented in mock") -} - -func (e *Endpoint) String() string { - panic("not implemented in mock") -} - -func (e *Endpoint) Copy() endpoint.Endpoint { - c := *e - - return &c -} - -func (e *Endpoint) Touch(opts ...endpoint.Option) { -} diff --git a/internal/mock/conn_info.go b/internal/mock/conn_info.go new file mode 100644 index 000000000..9f84a8aa2 --- /dev/null +++ b/internal/mock/conn_info.go @@ -0,0 +1,28 @@ +package mock + +import ( + "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" +) + +type ConnInfo struct { + EndpointAddrField string + EndpointLocationField string + EndpointNodeIDField uint32 + EndpointLocalDCField bool + + ConnState conn.State +} + +func (c *ConnInfo) State() conn.State { + return c.ConnState +} + +func (c *ConnInfo) Endpoint() endpoint.Endpoint { + return &Endpoint{ + AddrField: c.EndpointAddrField, + LocalDCField: c.EndpointLocalDCField, + LocationField: c.EndpointLocationField, + NodeIDField: c.EndpointNodeIDField, + } +} diff --git a/internal/mock/endpoint.go b/internal/mock/endpoint.go new file mode 100644 index 000000000..24aa51de9 --- /dev/null +++ b/internal/mock/endpoint.go @@ -0,0 +1,54 @@ +package mock + +import ( + "time" + + 
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" +) + +type Endpoint struct { + AddrField string + LocationField string + NodeIDField uint32 + LocalDCField bool +} + +func (e *Endpoint) Choose(bool) { +} + +func (e *Endpoint) NodeID() uint32 { + return e.NodeIDField +} + +func (e *Endpoint) Address() string { + return e.AddrField +} + +func (e *Endpoint) LocalDC() bool { + return e.LocalDCField +} + +func (e *Endpoint) Location() string { + return e.LocationField +} + +func (e *Endpoint) LastUpdated() time.Time { + panic("not implemented in mock") +} + +func (e *Endpoint) LoadFactor() float32 { + panic("not implemented in mock") +} + +func (e *Endpoint) String() string { + panic("not implemented in mock") +} + +func (e *Endpoint) Copy() endpoint.Endpoint { + c := *e + + return &c +} + +func (e *Endpoint) Touch(opts ...endpoint.Option) { +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index b7a5e8a3f..3368aadc8 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -4,6 +4,8 @@ import ( "context" "time" + "golang.org/x/sync/errgroup" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/pool/stats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" @@ -465,20 +467,16 @@ func (p *Pool[PT, T]) Close(ctx context.Context) (finalErr error) { p.mu.Lock() defer p.mu.Unlock() - errs := make([]error, 0, len(p.index)) - + var g errgroup.Group for item := range p.index { - if err := item.Close(ctx); err != nil { - errs = append(errs, err) - } + item := item + g.Go(func() error { + return item.Close(ctx) + }) } - - switch len(errs) { - case 0: - return nil - case 1: - return errs[0] - default: - return xerrors.Join(errs...) 
+ if err := g.Wait(); err != nil { + return xerrors.WithStackTrace(err) } + + return nil } diff --git a/internal/table/client.go b/internal/table/client.go index 20dd9cb6e..03e1eeb58 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -537,10 +537,6 @@ func (c *Client) internalPoolWaitFromCh(ctx context.Context, t *trace.Table) (s // errClosedClient. // If Client is overflow calls s.Close(ctx) and returns // errSessionPoolOverflow. -// -// Note that Put() must be called only once after being created or received by -// Get() or Take() calls. In other way it will produce unexpected behavior or -// panic. func (c *Client) Put(ctx context.Context, s *session) (err error) { onDone := trace.TableOnPoolPut(c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/table.(*Client).Put"), diff --git a/internal/xerrors/pessimized_error_test.go b/internal/xerrors/pessimized_error_test.go deleted file mode 100644 index 403587849..000000000 --- a/internal/xerrors/pessimized_error_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package xerrors - -import ( - "context" - "errors" - "fmt" - "testing" - - grpcCodes "google.golang.org/grpc/codes" - grpcStatus "google.golang.org/grpc/status" -) - -func TestMustPessimizeEndpoint(t *testing.T) { - for _, test := range []struct { - error error - pessimize bool - }{ - { - error: Transport(grpcStatus.Error(grpcCodes.Canceled, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Unknown, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.InvalidArgument, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.DeadlineExceeded, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.NotFound, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.AlreadyExists, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.PermissionDenied, "")), - 
pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.ResourceExhausted, "")), - pessimize: false, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.FailedPrecondition, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Aborted, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.OutOfRange, "")), - pessimize: false, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Unimplemented, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Internal, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Unavailable, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.DataLoss, "")), - pessimize: true, - }, - { - error: Transport(grpcStatus.Error(grpcCodes.Unauthenticated, "")), - pessimize: true, - }, - { - error: context.Canceled, - pessimize: false, - }, - { - error: context.DeadlineExceeded, - pessimize: false, - }, - { - error: fmt.Errorf("user error"), - pessimize: false, - }, - } { - err := errors.Unwrap(test.error) - if err == nil { - err = test.error - } - t.Run(err.Error(), func(t *testing.T) { - pessimize := MustPessimizeEndpoint(test.error) - if pessimize != test.pessimize { - t.Errorf("unexpected pessimization status for error `%v`: %t, exp: %t", test.error, pessimize, test.pessimize) - } - }) - } -} diff --git a/internal/xerrors/transport.go b/internal/xerrors/transport.go index b66f7735c..1fa8c0858 100644 --- a/internal/xerrors/transport.go +++ b/internal/xerrors/transport.go @@ -163,7 +163,7 @@ func Transport(err error, opts ...teOpt) error { return te } -func MustPessimizeEndpoint(err error, codes ...grpcCodes.Code) bool { +func MustBanConn(err error, codes ...grpcCodes.Code) bool { switch { case err == nil: return false diff --git a/internal/xerrors/transport_test.go b/internal/xerrors/transport_test.go index d94aa8ee4..527a24269 100644 --- a/internal/xerrors/transport_test.go +++ 
b/internal/xerrors/transport_test.go @@ -1,6 +1,8 @@ package xerrors import ( + "context" + "errors" "fmt" "testing" @@ -198,3 +200,98 @@ func TestTransportErrorName(t *testing.T) { }) } } + +func TestMustBanConn(t *testing.T) { + for _, test := range []struct { + error error + pessimize bool + }{ + { + error: Transport(grpcStatus.Error(grpcCodes.Canceled, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Unknown, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.InvalidArgument, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.DeadlineExceeded, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.NotFound, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.AlreadyExists, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.PermissionDenied, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.ResourceExhausted, "")), + pessimize: false, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.FailedPrecondition, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Aborted, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.OutOfRange, "")), + pessimize: false, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Unimplemented, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Internal, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Unavailable, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.DataLoss, "")), + pessimize: true, + }, + { + error: Transport(grpcStatus.Error(grpcCodes.Unauthenticated, "")), + pessimize: true, + }, + { + error: context.Canceled, + pessimize: false, + }, + { + error: context.DeadlineExceeded, + pessimize: false, + }, + { + error: fmt.Errorf("user error"), + pessimize: false, + }, + } { + err 
:= errors.Unwrap(test.error) + if err == nil { + err = test.error + } + t.Run(err.Error(), func(t *testing.T) { + pessimize := MustBanConn(test.error) + if pessimize != test.pessimize { + t.Errorf("unexpected pessimization status for error `%v`: %t, exp: %t", test.error, pessimize, test.pessimize) + } + }) + } +} diff --git a/log/driver.go b/log/driver.go index fb9c8b1a1..8596c8889 100644 --- a/log/driver.go +++ b/log/driver.go @@ -337,7 +337,7 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo ) } }, - OnConnAllow: func(info trace.DriverConnAllowStartInfo) func(trace.DriverConnAllowDoneInfo) { + OnConnUnban: func(info trace.DriverConnUnbanStartInfo) func(trace.DriverConnUnbanDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -348,7 +348,7 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo ) start := time.Now() - return func(info trace.DriverConnAllowDoneInfo) { + return func(info trace.DriverConnUnbanDoneInfo) { l.Log(ctx, "done", Stringer("endpoint", endpoint), latencyField(start), diff --git a/options.go b/options.go index 9734b9b75..f4aa5b9de 100644 --- a/options.go +++ b/options.go @@ -656,6 +656,6 @@ func withConnPool(pool *conn.Pool) Option { return func(ctx context.Context, c *Driver) error { c.pool = pool - return pool.Take(ctx) + return pool.Attach(ctx) } } diff --git a/tests/integration/discovery_test.go b/tests/integration/discovery_test.go index 511043ebd..b73ff4638 100644 --- a/tests/integration/discovery_test.go +++ b/tests/integration/discovery_test.go @@ -6,6 +6,7 @@ package integration import ( "context" "crypto/tls" + "github.com/stretchr/testify/require" "os" "testing" "time" @@ -26,7 +27,7 @@ func TestDiscovery(t *testing.T) { var ( userAgent = "connection user agent" requestType = "connection request type" - checkMedatada = func(ctx context.Context) { + checkMetadata = func(ctx context.Context) { md, has := metadata.FromOutgoingContext(ctx) if !has { 
t.Fatalf("no medatada") @@ -78,7 +79,7 @@ func TestDiscovery(t *testing.T) { invoker grpc.UnaryInvoker, opts ...grpc.CallOption, ) error { - checkMedatada(ctx) + checkMetadata(ctx) return invoker(ctx, method, req, reply, cc, opts...) }), grpc.WithStreamInterceptor(func( @@ -89,7 +90,7 @@ func TestDiscovery(t *testing.T) { streamer grpc.Streamer, opts ...grpc.CallOption, ) (grpc.ClientStream, error) { - checkMedatada(ctx) + checkMetadata(ctx) return streamer(ctx, desc, cc, method, opts...) }), ), @@ -105,29 +106,13 @@ func TestDiscovery(t *testing.T) { if err != nil { t.Fatal(err) } - defer func() { - // cleanup connection - if e := db.Close(ctx); e != nil { - t.Fatalf("db close failed: %+v", e) - } - }() - t.Run("discovery.Discover", func(t *testing.T) { - if endpoints, err := db.Discovery().Discover(ctx); err != nil { - t.Fatal(err) - } else { - t.Log(endpoints) - } - t.Run("wait", func(t *testing.T) { - t.Run("parking", func(t *testing.T) { - <-parking // wait for parking conn - t.Run("re-discover", func(t *testing.T) { - if endpoints, err := db.Discovery().Discover(ctx); err != nil { - t.Fatal(err) - } else { - t.Log(endpoints) - } - }) - }) - }) - }) + endpoints, err := db.Discovery().Discover(ctx) + require.NoError(t, err) + require.NotEmpty(t, endpoints) + <-parking // wait for parking conn + endpoints, err = db.Discovery().Discover(ctx) + require.NoError(t, err) + require.NotEmpty(t, endpoints) + err = db.Close(ctx) + require.NoError(t, err) } diff --git a/tests/slo/prometheus.yml b/tests/slo/prometheus.yml new file mode 100644 index 000000000..0c8f893aa --- /dev/null +++ b/tests/slo/prometheus.yml @@ -0,0 +1,7 @@ +scrape_configs: + - job_name: ydb-go-sdk + scrape_interval: 5s + static_configs: + - targets: +# - localhost:8000 + - docker.for.mac.host.internal:8080 \ No newline at end of file diff --git a/trace/driver.go b/trace/driver.go index e33dac555..225d7c829 100644 --- a/trace/driver.go +++ b/trace/driver.go @@ -21,7 +21,11 @@ type ( OnClose 
func(DriverCloseStartInfo) func(DriverCloseDoneInfo) // Pool of connections - OnPoolNew func(DriverConnPoolNewStartInfo) func(DriverConnPoolNewDoneInfo) + OnPoolNew func(DriverConnPoolNewStartInfo) func(DriverConnPoolNewDoneInfo) + OnPoolAttach func(DriverConnPoolAttachStartInfo) func(DriverConnPoolAttachDoneInfo) + OnPoolDetach func(DriverConnPoolDetachStartInfo) func(DriverConnPoolDetachDoneInfo) + + // Deprecated OnPoolRelease func(DriverConnPoolReleaseStartInfo) func(DriverConnPoolReleaseDoneInfo) // Resolver events @@ -36,7 +40,7 @@ type ( OnConnStreamCloseSend func(DriverConnStreamCloseSendStartInfo) func(DriverConnStreamCloseSendDoneInfo) OnConnDial func(DriverConnDialStartInfo) func(DriverConnDialDoneInfo) OnConnBan func(DriverConnBanStartInfo) func(DriverConnBanDoneInfo) - OnConnAllow func(DriverConnAllowStartInfo) func(DriverConnAllowDoneInfo) + OnConnUnban func(DriverConnUnbanStartInfo) func(DriverConnUnbanDoneInfo) OnConnPark func(DriverConnParkStartInfo) func(DriverConnParkDoneInfo) OnConnClose func(DriverConnCloseStartInfo) func(DriverConnCloseDoneInfo) @@ -264,7 +268,7 @@ type ( DriverConnBanDoneInfo struct { State ConnState } - DriverConnAllowStartInfo struct { + DriverConnUnbanStartInfo struct { // Context make available context in trace callback function. // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. @@ -274,7 +278,7 @@ type ( Endpoint EndpointInfo State ConnState } - DriverConnAllowDoneInfo struct { + DriverConnUnbanDoneInfo struct { State ConnState } DriverConnInvokeStartInfo struct { @@ -461,6 +465,28 @@ type ( DriverConnPoolReleaseDoneInfo struct { Error error } + DriverConnPoolAttachStartInfo struct { + // Context make available context in trace callback function. + // Pointer to context provide replacement of context in trace callback function. + // Warning: concurrent access to pointer on client side must be excluded. 
+ // Safe replacement of context are provided only inside callback function + Context *context.Context + Call call + } + DriverConnPoolAttachDoneInfo struct { + Error error + } + DriverConnPoolDetachStartInfo struct { + // Context make available context in trace callback function. + // Pointer to context provide replacement of context in trace callback function. + // Warning: concurrent access to pointer on client side must be excluded. + // Safe replacement of context are provided only inside callback function + Context *context.Context + Call call + } + DriverConnPoolDetachDoneInfo struct { + Error error + } DriverCloseStartInfo struct { // Context make available context in trace callback function. // Pointer to context provide replacement of context in trace callback function. diff --git a/trace/driver_gtrace.go b/trace/driver_gtrace.go index 4e737f5db..a1f108d75 100644 --- a/trace/driver_gtrace.go +++ b/trace/driver_gtrace.go @@ -170,6 +170,76 @@ func (t *Driver) Compose(x *Driver, opts ...DriverComposeOption) *Driver { } } } + { + h1 := t.OnPoolAttach + h2 := x.OnPoolAttach + ret.OnPoolAttach = func(d DriverConnPoolAttachStartInfo) func(DriverConnPoolAttachDoneInfo) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + } + }() + } + var r, r1 func(DriverConnPoolAttachDoneInfo) + if h1 != nil { + r = h1(d) + } + if h2 != nil { + r1 = h2(d) + } + return func(d DriverConnPoolAttachDoneInfo) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + } + }() + } + if r != nil { + r(d) + } + if r1 != nil { + r1(d) + } + } + } + } + { + h1 := t.OnPoolDetach + h2 := x.OnPoolDetach + ret.OnPoolDetach = func(d DriverConnPoolDetachStartInfo) func(DriverConnPoolDetachDoneInfo) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + } + }() + } + var r, r1 func(DriverConnPoolDetachDoneInfo) + if 
h1 != nil { + r = h1(d) + } + if h2 != nil { + r1 = h2(d) + } + return func(d DriverConnPoolDetachDoneInfo) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + } + }() + } + if r != nil { + r(d) + } + if r1 != nil { + r1(d) + } + } + } + } { h1 := t.OnPoolRelease h2 := x.OnPoolRelease @@ -521,9 +591,9 @@ func (t *Driver) Compose(x *Driver, opts ...DriverComposeOption) *Driver { } } { - h1 := t.OnConnAllow - h2 := x.OnConnAllow - ret.OnConnAllow = func(d DriverConnAllowStartInfo) func(DriverConnAllowDoneInfo) { + h1 := t.OnConnUnban + h2 := x.OnConnUnban + ret.OnConnUnban = func(d DriverConnUnbanStartInfo) func(DriverConnUnbanDoneInfo) { if options.panicCallback != nil { defer func() { if e := recover(); e != nil { @@ -531,14 +601,14 @@ func (t *Driver) Compose(x *Driver, opts ...DriverComposeOption) *Driver { } }() } - var r, r1 func(DriverConnAllowDoneInfo) + var r, r1 func(DriverConnUnbanDoneInfo) if h1 != nil { r = h1(d) } if h2 != nil { r1 = h2(d) } - return func(d DriverConnAllowDoneInfo) { + return func(d DriverConnUnbanDoneInfo) { if options.panicCallback != nil { defer func() { if e := recover(); e != nil { @@ -932,6 +1002,36 @@ func (t *Driver) onPoolNew(d DriverConnPoolNewStartInfo) func(DriverConnPoolNewD } return res } +func (t *Driver) onPoolAttach(d DriverConnPoolAttachStartInfo) func(DriverConnPoolAttachDoneInfo) { + fn := t.OnPoolAttach + if fn == nil { + return func(DriverConnPoolAttachDoneInfo) { + return + } + } + res := fn(d) + if res == nil { + return func(DriverConnPoolAttachDoneInfo) { + return + } + } + return res +} +func (t *Driver) onPoolDetach(d DriverConnPoolDetachStartInfo) func(DriverConnPoolDetachDoneInfo) { + fn := t.OnPoolDetach + if fn == nil { + return func(DriverConnPoolDetachDoneInfo) { + return + } + } + res := fn(d) + if res == nil { + return func(DriverConnPoolDetachDoneInfo) { + return + } + } + return res +} func (t *Driver) onPoolRelease(d 
DriverConnPoolReleaseStartInfo) func(DriverConnPoolReleaseDoneInfo) { fn := t.OnPoolRelease if fn == nil { @@ -1082,16 +1182,16 @@ func (t *Driver) onConnBan(d DriverConnBanStartInfo) func(DriverConnBanDoneInfo) } return res } -func (t *Driver) onConnAllow(d DriverConnAllowStartInfo) func(DriverConnAllowDoneInfo) { - fn := t.OnConnAllow +func (t *Driver) onConnUnban(d DriverConnUnbanStartInfo) func(DriverConnUnbanDoneInfo) { + fn := t.OnConnUnban if fn == nil { - return func(DriverConnAllowDoneInfo) { + return func(DriverConnUnbanDoneInfo) { return } } res := fn(d) if res == nil { - return func(DriverConnAllowDoneInfo) { + return func(DriverConnUnbanDoneInfo) { return } } @@ -1281,6 +1381,28 @@ func DriverOnPoolNew(t *Driver, c *context.Context, call call) func() { res(p) } } +func DriverOnPoolAttach(t *Driver, c *context.Context, call call) func(error) { + var p DriverConnPoolAttachStartInfo + p.Context = c + p.Call = call + res := t.onPoolAttach(p) + return func(e error) { + var p DriverConnPoolAttachDoneInfo + p.Error = e + res(p) + } +} +func DriverOnPoolDetach(t *Driver, c *context.Context, call call) func(error) { + var p DriverConnPoolDetachStartInfo + p.Context = c + p.Call = call + res := t.onPoolDetach(p) + return func(e error) { + var p DriverConnPoolDetachDoneInfo + p.Error = e + res(p) + } +} func DriverOnPoolRelease(t *Driver, c *context.Context, call call) func(error) { var p DriverConnPoolReleaseStartInfo p.Context = c @@ -1407,15 +1529,15 @@ func DriverOnConnBan(t *Driver, c *context.Context, call call, endpoint Endpoint res(p) } } -func DriverOnConnAllow(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { - var p DriverConnAllowStartInfo +func DriverOnConnUnban(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { + var p DriverConnUnbanStartInfo p.Context = c p.Call = call p.Endpoint = endpoint p.State = state - res := t.onConnAllow(p) + res 
:= t.onConnUnban(p) return func(state ConnState) { - var p DriverConnAllowDoneInfo + var p DriverConnUnbanDoneInfo p.State = state res(p) }