diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c5bd059a..4ca62809b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added `table/options.WithQueryService()` option for redirect `/Ydb.Table.V1.TableService/ExecuteDataQuery` call to `/Ydb.Query.V1.QueryService/ExecuteQuery` + ## v3.65.2 * Fixed data race using `log.WithNames` diff --git a/balancers/balancers.go b/balancers/balancers.go index 2fe525a53..81b7fd144 100644 --- a/balancers/balancers.go +++ b/balancers/balancers.go @@ -1,11 +1,11 @@ package balancers import ( + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "sort" "strings" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) @@ -28,8 +28,8 @@ func SingleConn() *balancerConfig.Config { type filterLocalDC struct{} -func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Conn) bool { - return c.Endpoint().Location() == info.SelfLocation +func (filterLocalDC) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return e.Location() == info.SelfLocation } func (filterLocalDC) String() string { @@ -58,8 +58,8 @@ func PreferLocalDCWithFallBack(balancer *balancerConfig.Config) *balancerConfig. type filterLocations []string -func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Conn) bool { - location := strings.ToUpper(c.Endpoint().Location()) +func (locations filterLocations) Allow(_ balancerConfig.Info, e endpoint.Info) bool { + location := strings.ToUpper(e.Location()) for _, l := range locations { if location == l { return true @@ -122,10 +122,10 @@ type Endpoint interface { LocalDC() bool } -type filterFunc func(info balancerConfig.Info, c conn.Conn) bool +type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool -func (p filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool { - return p(info, c) +func (p filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return p(info, e) } func (p filterFunc) String() string { @@ -135,8 +135,8 @@ func (p filterFunc) String() string { // Prefer creates balancer which use endpoints by filter // Balancer "balancer" defines balancing algorithm between endpoints selected with filter func Prefer(balancer *balancerConfig.Config, filter func(endpoint Endpoint) bool) *balancerConfig.Config { - balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { - return filter(c.Endpoint()) + balancer.Filter = filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return filter(e) }) return balancer diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index f69ec11e2..b2b92b608 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -128,7 +128,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { } } - b.applyDiscoveredEndpoints(ctx, endpoints, localDC) + b.applyEndpoints(ctx, endpoints, localDC) return nil } @@ -142,40 +142,40 @@ func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Con added = make([]trace.EndpointInfo, 0, len(previousConns)) dropped = make([]trace.EndpointInfo, 0, len(previousConns)) var ( - newestMap = make(map[string]struct{}, len(newestEndpoints)) + newestMap = make(map[string]endpoint.Endpoint, len(newestEndpoints)) previousMap = make(map[string]struct{}, len(previousConns)) ) sort.Slice(newestEndpoints, func(i, j int) bool { return newestEndpoints[i].Address() < newestEndpoints[j].Address() }) sort.Slice(previousConns, func(i, j int) bool { - return previousConns[i].Endpoint().Address() < previousConns[j].Endpoint().Address() + return previousConns[i].Address() < previousConns[j].Address() }) for _, e := range previousConns { - previousMap[e.Endpoint().Address()] = struct{}{} + previousMap[e.Address()] = struct{}{} } for _, e := range newestEndpoints { - nodes = append(nodes, e.Copy()) - newestMap[e.Address()] = struct{}{} + nodes = append(nodes, e) + newestMap[e.Address()] = e if _, has := previousMap[e.Address()]; !has { - added = append(added, e.Copy()) + added = append(added, e) } } for _, c := range previousConns { - if _, has := newestMap[c.Endpoint().Address()]; !has { - dropped = append(dropped, c.Endpoint().Copy()) + if e, has := newestMap[c.Address()]; !has { + dropped = append(dropped, e) } } return nodes, added, dropped } -func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []endpoint.Endpoint, localDC string) { +func (b *Balancer) applyEndpoints(ctx context.Context, endpoints []endpoint.Endpoint, localDC string) { var ( onDone = trace.DriverOnBalancerUpdate( b.driverConfig.Trace(), &ctx, stack.FunctionID( - "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), + "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyEndpoints"), b.config.DetectLocalDC, ) previousConns []conn.Conn @@ -186,10 +186,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end }() connections := endpointsToConnections(b.pool, endpoints) - for _, c := range connections { + connections.Each(func(c conn.Conn, e endpoint.Endpoint) { b.pool.Allow(ctx, c) - c.Endpoint().Touch() - } + }) info := balancerConfig.Info{SelfLocation: localDC} state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) @@ -260,7 +259,7 @@ func New( localDCDetector: detectLocalDC, } d := internalDiscovery.New(ctx, pool.Get( - endpoint.New(driverConfig.Endpoint()), + driverConfig.Endpoint(), ), discoveryConfig) b.discoveryClient = d @@ -272,8 +271,8 @@ func New( } if b.config.SingleConn { - b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{ - endpoint.New(driverConfig.Endpoint()), + b.applyEndpoints(ctx, []endpoint.Endpoint{ + endpoint.New(0, driverConfig.Endpoint(), ""), }, "") } else { // initialization of balancer state @@ -348,8 +347,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc if conn.UseWrapping(ctx) { if credentials.IsAccessError(err) { err = credentials.AccessError("no access", err, - credentials.WithAddress(cc.Endpoint().String()), - credentials.WithNodeID(cc.Endpoint().NodeID()), + credentials.WithAddress(cc.Address()), credentials.WithCredentials(b.driverConfig.Credentials()), ) } @@ -377,9 +375,9 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { ) defer func() { if err == nil { - onDone(c.Endpoint(), nil) + onDone(c.Address(), nil) } else { - onDone(nil, err) + onDone("", err) } }() @@ -408,11 +406,43 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { return c, nil } -func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { - conns := make([]conn.Conn, 0, len(endpoints)) - for _, e := range endpoints { - conns = append(conns, p.Get(e)) +type connByEndpoint struct { + conns map[string]conn.Conn + endpoints map[string]endpoint.Endpoint +} + +func (m *connByEndpoint) Get(address string) (conn.Conn, endpoint.Endpoint) { + return m.conns[address], m.endpoints[address] +} + +func (m *connByEndpoint) Each(visitor func(c conn.Conn, e endpoint.Endpoint)) { + for _, e := range m.endpoints { + visitor(m.conns[e.Address()], e) } +} + +func (m *connByEndpoint) Len() int { + return len(m.endpoints) +} +func (m *connByEndpoint) Conns() (conns []conn.Conn) { + conns = make([]conn.Conn, 0, len(m.conns)) + for _, c := range m.conns { + conns = append(conns, c) + } return conns } + +func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) *connByEndpoint { + m := &connByEndpoint{ + conns: make(map[string]conn.Conn, len(endpoints)), + endpoints: make(map[string]endpoint.Endpoint, len(endpoints)), + } + + for _, e := range endpoints { + m.conns[e.Address()] = p.Get(e.Address()) + m.endpoints[e.Address()] = e + } + + return m +} diff --git a/internal/balancer/config/routerconfig.go b/internal/balancer/config/routerconfig.go index 0d1eb6703..b1d0dce27 100644 --- a/internal/balancer/config/routerconfig.go +++ b/internal/balancer/config/routerconfig.go @@ -2,8 +2,8 @@ package config import ( "fmt" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" ) @@ -47,6 +47,6 @@ type Info struct { } type Filter interface { - Allow(info Info, c conn.Conn) bool + Allow(info Info, e endpoint.Info) bool String() string } diff --git a/internal/balancer/connections_state.go b/internal/balancer/connections_state.go index e9196ead7..d8abf9ba3 100644 --- a/internal/balancer/connections_state.go +++ b/internal/balancer/connections_state.go @@ -2,6 +2,7 @@ package balancer import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" @@ -19,7 +20,7 @@ type connectionsState struct { } func newConnectionsState( - conns []conn.Conn, + conns *connByEndpoint, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, @@ -31,7 +32,7 @@ func newConnectionsState( res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback) if allowFallback { - res.all = conns + res.all = conns.Conns() } else { res.all = res.prefer } @@ -115,40 +116,40 @@ func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned return nil, failedConns } -func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { - if len(conns) == 0 { +func connsToNodeIDMap(conns *connByEndpoint) (nodes map[uint32]conn.Conn) { + if conns.Len() == 0 { return nil } - nodes = make(map[uint32]conn.Conn, len(conns)) - for _, c := range conns { - nodes[c.Endpoint().NodeID()] = c - } + nodes = make(map[uint32]conn.Conn, conns.Len()) + conns.Each(func(c conn.Conn, e endpoint.Endpoint) { + nodes[e.NodeID()] = c + }) return nodes } func sortPreferConnections( - conns []conn.Conn, + conns *connByEndpoint, filter balancerConfig.Filter, info balancerConfig.Info, allowFallback bool, ) (prefer, fallback []conn.Conn) { if filter == nil { - return conns, nil + return conns.Conns(), nil } - prefer = make([]conn.Conn, 0, len(conns)) + prefer = make([]conn.Conn, 0, conns.Len()) if allowFallback { - fallback = make([]conn.Conn, 0, len(conns)) + fallback = make([]conn.Conn, 0, conns.Len()) } - for _, c := range conns { - if filter.Allow(info, c) { + conns.Each(func(c conn.Conn, e endpoint.Endpoint) { + if filter.Allow(info, e) { prefer = append(prefer, c) } else if allowFallback { fallback = append(fallback, c) } - } + }) return prefer, fallback } diff --git a/internal/balancer/connections_state_test.go b/internal/balancer/connections_state_test.go index b052b3933..b21defd51 100644 --- a/internal/balancer/connections_state_test.go +++ b/internal/balancer/connections_state_test.go @@ -2,6 +2,7 @@ package balancer import ( "context" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "strings" "testing" @@ -65,10 +66,10 @@ func TestConnsToNodeIDMap(t *testing.T) { } } -type filterFunc func(info balancerConfig.Info, c conn.Conn) bool +type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool -func (f filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool { - return f(info, c) +func (f filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { + return f(info, e) } func (f filterFunc) String() string { @@ -115,8 +116,8 @@ func TestSortPreferConnections(t *testing.T) { &mock.Conn{AddrField: "f2"}, }, allowFallback: false, - filter: filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { - return strings.HasPrefix(c.Endpoint().Address(), "t") + filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return strings.HasPrefix(e.Address(), "t") }), prefer: []conn.Conn{ &mock.Conn{AddrField: "t1"}, @@ -133,8 +134,8 @@ func TestSortPreferConnections(t *testing.T) { &mock.Conn{AddrField: "f2"}, }, allowFallback: true, - filter: filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool { - return strings.HasPrefix(c.Endpoint().Address(), "t") + filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { + return strings.HasPrefix(e.Address(), "t") }), prefer: []conn.Conn{ &mock.Conn{AddrField: "t1"}, @@ -190,7 +191,7 @@ func TestSelectRandomConnection(t *testing.T) { second := 0 for i := 0; i < 100; i++ { c, _ := s.selectRandomConnection(conns, false) - if c.Endpoint().Address() == "1" { + if c.Address() == "1" { first++ } else { second++ @@ -225,13 +226,13 @@ func TestSelectRandomConnection(t *testing.T) { for i := 0; i < 100; i++ { c, checkFailed := s.selectRandomConnection(conns, false) failed += checkFailed - switch c.Endpoint().Address() { + switch c.Address() { case "1": first++ case "2": second++ default: - t.Errorf(c.Endpoint().Address()) + t.Errorf(c.Address()) } } require.Equal(t, 100, first+second) @@ -286,8 +287,8 @@ func TestNewState(t *testing.T) { &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { - return info.SelfLocation == c.Endpoint().Location() + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() }), balancerConfig.Info{SelfLocation: "t"}, false), res: &connectionsState{ connByNodeID: map[uint32]conn.Conn{ @@ -314,8 +315,8 @@ func TestNewState(t *testing.T) { &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { - return info.SelfLocation == c.Endpoint().Location() + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() }), balancerConfig.Info{SelfLocation: "t"}, true), res: &connectionsState{ connByNodeID: map[uint32]conn.Conn{ @@ -347,8 +348,8 @@ func TestNewState(t *testing.T) { &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { - return info.SelfLocation == c.Endpoint().Location() + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return info.SelfLocation == e.Location() }), balancerConfig.Info{SelfLocation: "t"}, true), res: &connectionsState{ connByNodeID: map[uint32]conn.Conn{ @@ -412,8 +413,8 @@ func TestConnection(t *testing.T) { s := newConnectionsState([]conn.Conn{ &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, &mock.Conn{AddrField: "f2", State: conn.Banned, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { - return c.Endpoint().Location() == info.SelfLocation + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return e.Location() == info.SelfLocation }), balancerConfig.Info{}, true) preferred := 0 fallback := 0 @@ -421,7 +422,7 @@ func TestConnection(t *testing.T) { c, failed := s.GetConnection(context.Background()) require.NotNil(t, c) require.Equal(t, 2, failed) - if c.Endpoint().Address() == "t1" { + if c.Address() == "t1" { preferred++ } else { fallback++ @@ -435,8 +436,8 @@ func TestConnection(t *testing.T) { s := newConnectionsState([]conn.Conn{ &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { - return c.Endpoint().Location() == info.SelfLocation + }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { + return e.Location() == info.SelfLocation }), balancerConfig.Info{SelfLocation: "t"}, true) c, failed := s.GetConnection(context.Background()) require.Equal(t, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 8192e38e7..579d9188c 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -13,7 +13,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" "github.com/ydb-platform/ydb-go-sdk/v3/internal/response" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" @@ -37,7 +36,7 @@ var ( type Conn interface { grpc.ClientConnInterface - Endpoint() endpoint.Endpoint + Address() string LastUsage() time.Time @@ -53,7 +52,7 @@ type conn struct { config Config // ro access cc *grpc.ClientConn done chan struct{} - endpoint endpoint.Endpoint // ro access + address string closed bool state atomic.Uint32 childStreams *xcontext.CancelsGuard @@ -63,16 +62,16 @@ type conn struct { } func (c *conn) Address() string { - return c.endpoint.Address() + return c.address } func (c *conn) Ping(ctx context.Context) error { cc, err := c.realConn(ctx) if err != nil { - return c.wrapError(err) + return xerrors.WithStackTrace(err) } if !isAvailable(cc) { - return c.wrapError(errUnavailableConnection) + return xerrors.WithStackTrace(errUnavailableConnection) } return nil @@ -96,19 +95,11 @@ func (c *conn) IsState(states ...State) bool { return false } -func (c *conn) NodeID() uint32 { - if c != nil { - return c.endpoint.NodeID() - } - - return 0 -} - func (c *conn) park(ctx context.Context) (err error) { onDone := trace.DriverOnConnPark( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).park"), - c.Endpoint(), + c.address, ) defer func() { onDone(err) @@ -128,15 +119,7 @@ func (c *conn) park(ctx context.Context) (err error) { err = c.close(ctx) if err != nil { - return c.wrapError(err) - } - - return nil -} - -func (c *conn) Endpoint() endpoint.Endpoint { - if c != nil { - return c.endpoint + return xerrors.WithStackTrace(err) } return nil @@ -151,7 +134,7 @@ func (c *conn) setState(ctx context.Context, s State) State { trace.DriverOnConnStateChange( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).setState"), - c.endpoint.Copy(), state, + c.address, state, )(s) } @@ -180,7 +163,7 @@ func (c *conn) GetState() (s State) { func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { if c.isClosed() { - return nil, c.wrapError(errClosedConnection) + return nil, xerrors.WithStackTrace(errClosedConnection) } c.mtx.Lock() @@ -199,7 +182,7 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { onDone := trace.DriverOnConnDial( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).realConn"), - c.endpoint.Copy(), + c.address, ) defer func() { onDone(err) @@ -207,7 +190,7 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { // prepend "ydb" scheme for grpc dns-resolver to find the proper scheme // three slashes in "ydb:///" is ok. It needs for good parse scheme in grpc resolver. - address := "ydb:///" + c.endpoint.Address() + address := "ydb:///" + c.address cc, err = grpc.DialContext(ctx, address, append( []grpc.DialOption{ @@ -227,7 +210,7 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { xerrors.WithAddress(address), ) - return nil, c.wrapError( + return nil, xerrors.WithStackTrace( xerrors.Retryable(err, xerrors.WithName("realConn"), ), @@ -259,7 +242,7 @@ func (c *conn) close(ctx context.Context) (err error) { c.cc = nil c.setState(ctx, Offline) - return c.wrapError(err) + return xerrors.WithStackTrace(err) } func (c *conn) isClosed() bool { @@ -280,7 +263,7 @@ func (c *conn) Close(ctx context.Context) (err error) { onDone := trace.DriverOnConnClose( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Close"), - c.Endpoint(), + c.address, ) defer func() { onDone(err) @@ -296,7 +279,7 @@ func (c *conn) Close(ctx context.Context) (err error) { onClose(c) } - return c.wrapError(err) + return xerrors.WithStackTrace(err) } func (c *conn) Invoke( @@ -313,7 +296,7 @@ func (c *conn) Invoke( onDone = trace.DriverOnConnInvoke( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).Invoke"), - c.endpoint, trace.Method(method), + c.address, trace.Method(method), ) cc *grpc.ClientConn md = metadata.MD{} @@ -325,7 +308,7 @@ func (c *conn) Invoke( cc, err = c.realConn(ctx) if err != nil { - return c.wrapError(err) + return xerrors.WithStackTrace(err) } stop := c.lastUsage.Start() @@ -354,10 +337,10 @@ func (c *conn) Invoke( xerrors.WithTraceID(traceID), ) if sentMark.canRetry() { - return c.wrapError(xerrors.Retryable(err, xerrors.WithName("Invoke"))) + return xerrors.WithStackTrace(xerrors.Retryable(err, xerrors.WithName("Invoke"))) } - return c.wrapError(err) + return xerrors.WithStackTrace(err) } return err @@ -371,10 +354,10 @@ func (c *conn) Invoke( if useWrapping { switch { case !o.GetOperation().GetReady(): - return c.wrapError(errOperationNotReady) + return xerrors.WithStackTrace(errOperationNotReady) case o.GetOperation().GetStatus() != Ydb.StatusIds_SUCCESS: - return c.wrapError( + return xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(o.GetOperation()), xerrors.WithAddress(c.Address()), @@ -399,7 +382,7 @@ func (c *conn) NewStream( onDone = trace.DriverOnConnNewStream( c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*conn).NewStream"), - c.endpoint.Copy(), trace.Method(method), + c.address, trace.Method(method), ) useWrapping = UseWrapping(ctx) ) @@ -410,7 +393,7 @@ func (c *conn) NewStream( cc, err := c.realConn(ctx) if err != nil { - return nil, c.wrapError(err) + return nil, xerrors.WithStackTrace(err) } stop := c.lastUsage.Start() @@ -451,10 +434,10 @@ func (c *conn) NewStream( xerrors.WithTraceID(traceID), ) if sentMark.canRetry() { - return s, c.wrapError(xerrors.Retryable(err, xerrors.WithName("NewStream"))) + return s, xerrors.WithStackTrace(xerrors.Retryable(err, xerrors.WithName("NewStream"))) } - return s, c.wrapError(err) + return s, xerrors.WithStackTrace(err) } return s, err @@ -473,15 +456,6 @@ func (c *conn) NewStream( }, nil } -func (c *conn) wrapError(err error) error { - if err == nil { - return nil - } - nodeErr := newConnError(c.endpoint.NodeID(), c.endpoint.Address(), err) - - return xerrors.WithStackTrace(nodeErr, xerrors.WithSkipDepth(1)) -} - type option func(c *conn) func withOnClose(onClose func(*conn)) option { @@ -500,9 +474,9 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca } } -func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { +func newConn(address string, config Config, opts ...option) *conn { c := &conn{ - endpoint: e, + address: address, config: config, done: make(chan struct{}), lastUsage: xsync.NewLastUsage(), @@ -523,8 +497,8 @@ func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { return c } -func New(e endpoint.Endpoint, config Config, opts ...option) Conn { - return newConn(e, config, opts...) +func New(address string, config Config, opts ...option) Conn { + return newConn(address, config, opts...) } var _ stats.Handler = statsHandler{} diff --git a/internal/conn/error.go b/internal/conn/error.go deleted file mode 100644 index 216ac2f7b..000000000 --- a/internal/conn/error.go +++ /dev/null @@ -1,25 +0,0 @@ -package conn - -import "fmt" - -type connError struct { - nodeID uint32 - endpoint string - err error -} - -func newConnError(id uint32, endpoint string, err error) connError { - return connError{ - nodeID: id, - endpoint: endpoint, - err: err, - } -} - -func (n connError) Error() string { - return fmt.Sprintf("connError{node_id:%d,address:'%s'}: %v", n.nodeID, n.endpoint, n.err) -} - -func (n connError) Unwrap() error { - return n.err -} diff --git a/internal/conn/error_test.go b/internal/conn/error_test.go deleted file mode 100644 index 569a38a31..000000000 --- a/internal/conn/error_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package conn - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestNodeErrorError(t *testing.T) { - testErr := errors.New("test") - nodeErr := newConnError(1, "localhost:1234", testErr) - message := nodeErr.Error() - - require.Equal(t, "connError{node_id:1,address:'localhost:1234'}: test", message) -} - -func TestNodeErrorUnwrap(t *testing.T) { - testErr := errors.New("test") - nodeErr := newConnError(1, "asd", testErr) - - unwrapped := errors.Unwrap(nodeErr) - require.Equal(t, testErr, unwrapped) -} - -func TestNodeErrorIs(t *testing.T) { - testErr := errors.New("test") - testErr2 := errors.New("test2") - nodeErr := newConnError(1, "localhost:1234", testErr) - - require.ErrorIs(t, nodeErr, testErr) - require.NotErrorIs(t, nodeErr, testErr2) -} - -type testType1Error struct { - msg string -} - -func (t testType1Error) Error() string { - return "1 - " + t.msg -} - -type testType2Error struct { - msg string -} - -func (t testType2Error) Error() string { - return "2 - " + t.msg -} - -func TestNodeErrorAs(t *testing.T) { - testErr := testType1Error{msg: "test"} - nodeErr := newConnError(1, "localhost:1234", testErr) - - target := testType1Error{} - require.ErrorAs(t, nodeErr, &target) - require.Equal(t, testErr, target) - - target2 := testType2Error{} - require.False(t, errors.As(nodeErr, &target2)) -} diff --git a/internal/conn/grpc_client_stream.go b/internal/conn/grpc_client_stream.go index 32377e5ab..590d532c7 100644 --- a/internal/conn/grpc_client_stream.go +++ b/internal/conn/grpc_client_stream.go @@ -43,7 +43,7 @@ func (s *grpcClientStream) CloseSend() (err error) { } if s.wrapping { - return s.wrapError( + return xerrors.WithStackTrace( xerrors.Transport( err, xerrors.WithAddress(s.c.Address()), @@ -52,7 +52,7 @@ func (s *grpcClientStream) CloseSend() (err error) { ) } - return s.wrapError(err) + return xerrors.WithStackTrace(err) } return nil @@ -86,12 +86,12 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) { xerrors.WithTraceID(s.traceID), ) if s.sentMark.canRetry() { - return s.wrapError(xerrors.Retryable(err, + return xerrors.WithStackTrace(xerrors.Retryable(err, xerrors.WithName("SendMsg"), )) } - return s.wrapError(err) + return xerrors.WithStackTrace(err) } return err @@ -136,12 +136,12 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { xerrors.WithAddress(s.c.Address()), ) if s.sentMark.canRetry() { - return s.wrapError(xerrors.Retryable(err, + return xerrors.WithStackTrace(xerrors.Retryable(err, xerrors.WithName("RecvMsg"), )) } - return s.wrapError(err) + return xerrors.WithStackTrace(err) } return err @@ -150,7 +150,7 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { if s.wrapping { if operation, ok := m.(wrap.StreamOperationResponse); ok { if status := operation.GetStatus(); status != Ydb.StatusIds_SUCCESS { - return s.wrapError( + return xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(operation), xerrors.WithAddress(s.c.Address()), @@ -162,14 +162,3 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { return nil } - -func (s *grpcClientStream) wrapError(err error) error { - if err == nil { - return nil - } - - return xerrors.WithStackTrace( - newConnError(s.c.endpoint.NodeID(), s.c.endpoint.Address(), err), - xerrors.WithSkipDepth(1), - ) -} diff --git a/internal/conn/pool.go b/internal/conn/pool.go index 783b7a880..0248f9a78 100644 --- a/internal/conn/pool.go +++ b/internal/conn/pool.go @@ -10,7 +10,6 @@ import ( grpcCodes "google.golang.org/grpc/codes" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -18,44 +17,31 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -type connsKey struct { - address string - nodeID uint32 -} - type Pool struct { usages int64 config Config mtx xsync.RWMutex opts []grpc.DialOption - conns map[connsKey]*conn + conns map[string]*conn done chan struct{} } -func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { +func (p *Pool) Get(address string) Conn { p.mtx.Lock() defer p.mtx.Unlock() var ( - address = endpoint.Address() - cc *conn - has bool + cc *conn + has bool ) - key := connsKey{address, endpoint.NodeID()} - - if cc, has = p.conns[key]; has { + if cc, has = p.conns[address]; has { return cc } - cc = newConn( - endpoint, - p.config, - withOnClose(p.remove), - withOnTransportError(p.Ban), - ) + cc = newConn(address, p.config, withOnClose(p.remove), withOnTransportError(p.Ban)) - p.conns[key] = cc + p.conns[address] = cc return cc } @@ -63,7 +49,7 @@ func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { func (p *Pool) remove(c *conn) { p.mtx.Lock() defer p.mtx.Unlock() - delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()}) + delete(p.conns, c.address) } func (p *Pool) isClosed() bool { @@ -102,12 +88,10 @@ func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { return } - e := cc.Endpoint().Copy() - p.mtx.RLock() defer p.mtx.RUnlock() - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] + cc, ok := p.conns[cc.Address()] if !ok { return } @@ -115,7 +99,7 @@ func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { trace.DriverOnConnBan( p.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Ban"), - e, cc.GetState(), cause, + cc.Address(), cc.GetState(), cause, )(cc.SetState(ctx, Banned)) } @@ -124,12 +108,10 @@ func (p *Pool) Allow(ctx context.Context, cc Conn) { return } - e := cc.Endpoint().Copy() - p.mtx.RLock() defer p.mtx.RUnlock() - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] + cc, ok := p.conns[cc.Address()] if !ok { return } @@ -137,7 +119,7 @@ func (p *Pool) Allow(ctx context.Context, cc Conn) { trace.DriverOnConnAllow( p.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Allow"), - e, cc.GetState(), + cc.Address(), cc.GetState(), )(cc.Unban(ctx)) } @@ -241,7 +223,7 @@ func NewPool(ctx context.Context, config Config) *Pool { usages: 1, config: config, opts: config.GrpcDialOptions(), - conns: make(map[connsKey]*conn), + conns: make(map[string]*conn), done: make(chan struct{}), } diff --git a/internal/credentials/access_error.go b/internal/credentials/access_error.go index 777bc3d80..a962eda50 100644 --- a/internal/credentials/access_error.go +++ b/internal/credentials/access_error.go @@ -2,12 +2,10 @@ package credentials import ( "fmt" - "io" - "reflect" - "strconv" - "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" grpcCodes "google.golang.org/grpc/codes" + "io" + "reflect" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring" @@ -18,7 +16,6 @@ type authErrorOption interface { } var ( - _ authErrorOption = nodeIDAuthErrorOption(0) _ authErrorOption = addressAuthErrorOption("") _ authErrorOption = endpointAuthErrorOption("") _ authErrorOption = databaseAuthErrorOption("") @@ -57,17 +54,6 @@ func WithDatabase(database string) databaseAuthErrorOption { return databaseAuthErrorOption(database) } -type nodeIDAuthErrorOption uint32 - -func (id nodeIDAuthErrorOption) applyAuthErrorOption(w io.Writer) { - fmt.Fprint(w, "nodeID:") - fmt.Fprint(w, strconv.FormatUint(uint64(id), 10)) -} - -func WithNodeID(id uint32) authErrorOption { - return nodeIDAuthErrorOption(id) -} - type credentialsUnauthenticatedErrorOption struct { credentials Credentials } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index dfda2660c..5d9c62c81 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -5,6 +5,7 @@ import ( "io" "net" "strconv" + "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Discovery_V1" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" @@ -53,7 +54,7 @@ func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, e defer func() { nodes := make([]trace.EndpointInfo, 0, len(endpoints)) for _, e := range endpoints { - nodes = append(nodes, e.Copy()) + nodes = append(nodes, e) } onDone(location, nodes, err) }() @@ -84,10 +85,11 @@ func (c *Client) Discover(ctx context.Context) (endpoints []endpoint.Endpoint, e for _, e := range result.GetEndpoints() { if e.GetSsl() == c.config.Secure() { endpoints = append(endpoints, endpoint.New( + e.GetNodeId(), net.JoinHostPort(e.GetAddress(), strconv.Itoa(int(e.GetPort()))), - endpoint.WithLocation(e.GetLocation()), - endpoint.WithID(e.GetNodeId()), + e.GetLocation(), endpoint.WithLoadFactor(e.GetLoadFactor()), + endpoint.WithLastUpdated(time.Now()), endpoint.WithLocalDC(e.GetLocation() == location), endpoint.WithServices(e.GetService()), )) diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 37a889b81..76c25f9b6 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -2,7 +2,7 @@ package endpoint import ( "fmt" - "sync" + "sync/atomic" "time" ) @@ -24,70 +24,41 @@ type Endpoint interface { Info String() string - Copy() Endpoint Touch(opts ...Option) } type endpoint struct { //nolint:maligned - mu sync.RWMutex id uint32 address string location string services []string - loadFactor float32 - lastUpdated time.Time + loadFactor atomic.Pointer[float32] + lastUpdated atomic.Pointer[time.Time] local bool } -func (e *endpoint) Copy() Endpoint { - e.mu.RLock() - defer e.mu.RUnlock() - - return &endpoint{ - id: e.id, - address: e.address, - location: e.location, - services: append(make([]string, 0, len(e.services)), e.services...), - loadFactor: e.loadFactor, - local: e.local, - lastUpdated: e.lastUpdated, - } -} - func (e *endpoint) String() string { - e.mu.RLock() - defer e.mu.RUnlock() - return fmt.Sprintf(`{id:%d,address:%q,local:%t,location:%q,loadFactor:%f,lastUpdated:%q}`, e.id, e.address, e.local, e.location, - e.loadFactor, - e.lastUpdated.Format(time.RFC3339), + *e.loadFactor.Load(), + e.lastUpdated.Load().Format(time.RFC3339), ) } func (e *endpoint) NodeID() uint32 { - e.mu.RLock() - defer e.mu.RUnlock() - return e.id } func (e *endpoint) Address() (address string) { - e.mu.RLock() - defer e.mu.RUnlock() - return e.address } func (e *endpoint) Location() string { - e.mu.RLock() - defer e.mu.RUnlock() - return e.location } @@ -96,30 +67,19 @@ func (e *endpoint) Location() string { // Will be removed after Oct 2024. // Read about versioning policy: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#deprecated func (e *endpoint) LocalDC() bool { - e.mu.RLock() - defer e.mu.RUnlock() - return e.local } func (e *endpoint) LoadFactor() float32 { - e.mu.RLock() - defer e.mu.RUnlock() - - return e.loadFactor + return *e.loadFactor.Load() } func (e *endpoint) LastUpdated() time.Time { - e.mu.RLock() - defer e.mu.RUnlock() - - return e.lastUpdated + return *e.lastUpdated.Load() } func (e *endpoint) Touch(opts ...Option) { - e.mu.Lock() - defer e.mu.Unlock() - for _, opt := range append([]Option{withLastUpdated(time.Now())}, opts...) { + for _, opt := range append([]Option{WithLastUpdated(time.Now())}, opts...) { if opt != nil { opt(e) } @@ -128,18 +88,6 @@ func (e *endpoint) Touch(opts ...Option) { type Option func(e *endpoint) -func WithID(id uint32) Option { - return func(e *endpoint) { - e.id = id - } -} - -func WithLocation(location string) Option { - return func(e *endpoint) { - e.location = location - } -} - func WithLocalDC(local bool) Option { return func(e *endpoint) { e.local = local @@ -148,7 +96,7 @@ func WithLocalDC(local bool) Option { func WithLoadFactor(loadFactor float32) Option { return func(e *endpoint) { - e.loadFactor = loadFactor + e.loadFactor.Store(&loadFactor) } } @@ -158,17 +106,19 @@ func WithServices(services []string) Option { } } -func withLastUpdated(ts time.Time) Option { +func WithLastUpdated(ts time.Time) Option { return func(e *endpoint) { - e.lastUpdated = ts + e.lastUpdated.Store(&ts) } } -func New(address string, opts ...Option) *endpoint { +func New(nodeID uint32, address string, location string, opts ...Option) *endpoint { e := &endpoint{ - address: address, - lastUpdated: time.Now(), + id: nodeID, + address: address, + location: location, } + for _, opt := range opts { if opt != nil { opt(e) diff --git a/internal/mock/conn.go b/internal/mock/conn.go index b4ceb9f69..21ebac72f 100644 --- a/internal/mock/conn.go +++ b/internal/mock/conn.go @@ -19,6 +19,10 @@ type Conn struct { LocalDCField bool } +func (c *Conn) Address() string { + return c.AddrField +} + func (c *Conn) Invoke( ctx context.Context, method string, diff --git a/internal/params/parameters.go b/internal/params/parameters.go index aae4e38d0..005397413 100644 --- a/internal/params/parameters.go +++ b/internal/params/parameters.go @@ -61,7 +61,7 @@ func (p *Parameters) String() string { } func (p *Parameters) ToYDB(a *allocator.Allocator) map[string]*Ydb.TypedValue { - if p == nil { + if p == nil || len(*p) == 0 { return nil } parameters := make(map[string]*Ydb.TypedValue, len(*p)) diff --git a/internal/query/client.go b/internal/query/client.go index 3019e83cd..ce5916e04 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -21,7 +21,7 @@ import ( //go:generate mockgen -destination grpc_client_mock_test.go -package query -write_package_comment=false github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1 QueryServiceClient,QueryService_AttachSessionClient,QueryService_ExecuteQueryClient type nodeChecker interface { - HasNode(id uint32) bool + HasNode(id int64) bool } type balancer interface { @@ -197,7 +197,7 @@ func New(ctx context.Context, balancer balancer, cfg *config.Config) *Client { s, err := createSession(createCtx, client.grpcClient, cfg, withSessionCheck(func(s *Session) bool { - return balancer.HasNode(uint32(s.nodeID)) + return balancer.HasNode(s.nodeID) }), ) if err != nil { diff --git a/internal/query/client_test.go b/internal/query/client_test.go index 1b7260750..8ee870a23 100644 --- a/internal/query/client_test.go +++ b/internal/query/client_test.go @@ -164,7 +164,7 @@ func newTestSession(id string) *Session { func newTestSessionWithClient(id string, client Ydb_Query_V1.QueryServiceClient) *Session { return &Session{ id: id, - grpcClient: client, + client: client, statusCode: statusIdle, cfg: config.New(), } diff --git a/internal/query/errors.go b/internal/query/errors.go index 923ef8ed8..ef29fcd9f 100644 --- a/internal/query/errors.go +++ b/internal/query/errors.go @@ -7,6 +7,7 @@ import ( var ( ErrNotImplemented = errors.New("not implemented yet") errWrongNextResultSetIndex = errors.New("wrong result set index") + errNilResult = errors.New("nil result") errClosedResult = errors.New("result closed early") errClosedClient = errors.New("query client closed early") errWrongResultSetIndex = errors.New("critical violation of the logic - wrong result set index") diff --git a/internal/query/execute_query.go b/internal/query/execute_query.go index bf2fef314..990bc6c62 100644 --- a/internal/query/execute_query.go +++ b/internal/query/execute_query.go @@ -2,43 +2,29 @@ package query import ( "context" + "io" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" - "google.golang.org/grpc" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" - "github.com/ydb-platform/ydb-go-sdk/v3/query" ) -type executeConfig interface { - ExecMode() options.ExecMode - StatsMode() options.StatsMode - TxControl() *query.TransactionControl - Syntax() options.Syntax - Params() *params.Parameters - CallOptions() []grpc.CallOption +type withAllocatorOption struct { + a *allocator.Allocator } -func executeQueryRequest(a *allocator.Allocator, sessionID, q string, cfg executeConfig) ( - *Ydb_Query.ExecuteQueryRequest, - []grpc.CallOption, -) { - request := a.QueryExecuteQueryRequest() - - request.SessionId = sessionID - request.ExecMode = Ydb_Query.ExecMode(cfg.ExecMode()) - request.TxControl = cfg.TxControl().ToYDB(a) - request.Query = queryFromText(a, q, Ydb_Query.Syntax(cfg.Syntax())) - request.Parameters = cfg.Params().ToYDB(a) - request.StatsMode = Ydb_Query.StatsMode(cfg.StatsMode()) - request.ConcurrentResultSets = false +func (a withAllocatorOption) ApplyExecuteOption(s *options.Execute) { + s.Allocator = a.a +} - return request, cfg.CallOptions() +func WithAllocator(a *allocator.Allocator) options.ExecuteOption { + return withAllocatorOption{a: a} } func queryFromText( @@ -52,22 +38,71 @@ func queryFromText( return content } -func execute(ctx context.Context, s *Session, c Ydb_Query_V1.QueryServiceClient, q string, cfg executeConfig) ( - _ *transaction, _ *result, finalErr error, -) { - a := allocator.New() - defer a.Free() +func ReadAll(ctx context.Context, r *result) (resultSets []*Ydb.ResultSet, stats *Ydb_TableStats.QueryStats, _ error) { + if r == nil { + return nil, nil, xerrors.WithStackTrace(errNilResult) + } + for { + resultSet, err := r.nextResultSet(ctx) + if err != nil { + if xerrors.Is(err, io.EOF) { + return resultSets, resultSet.stats(), nil + } + + return nil, nil, xerrors.WithStackTrace(err) + } + var rows []*Ydb.Value + for { + row, err := resultSet.nextRow(ctx) + if err != nil { + if xerrors.Is(err, io.EOF) { + break + } + + return nil, nil, xerrors.WithStackTrace(err) + } + + rows = append(rows, row.v) + } + + resultSets = append(resultSets, &Ydb.ResultSet{ + Columns: resultSet.columns, + Rows: rows, + }) + } +} + +func Execute[T options.ExecuteOption]( + ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessionID, query string, opts ...T, +) (_ *transaction, _ *result, finalErr error) { + var ( + settings = options.ExecuteSettings(opts...) + a = settings.Allocator + ) + + if a == nil { + a = allocator.New() + defer a.Free() + } - request, callOptions := executeQueryRequest(a, s.id, q, cfg) + request := a.QueryExecuteQueryRequest() + + request.SessionId = sessionID + request.ExecMode = Ydb_Query.ExecMode(settings.ExecMode) + request.TxControl = settings.TxControl.ToYDB(a) + request.Query = queryFromText(a, query, Ydb_Query.Syntax(settings.Syntax)) + request.Parameters = settings.Params.ToYDB(a) + request.StatsMode = Ydb_Query.StatsMode(settings.StatsMode) + request.ConcurrentResultSets = false executeCtx, cancelExecute := xcontext.WithCancel(xcontext.ValueOnly(ctx)) - stream, err := c.ExecuteQuery(executeCtx, request, callOptions...) + stream, err := client.ExecuteQuery(executeCtx, request, settings.GrpcCallOptions...) if err != nil { return nil, nil, xerrors.WithStackTrace(err) } - r, txID, err := newResult(ctx, stream, s.cfg.Trace(), cancelExecute) + r, txID, err := newResult(ctx, stream, settings.Trace, cancelExecute) if err != nil { cancelExecute() @@ -78,5 +113,5 @@ func execute(ctx context.Context, s *Session, c Ydb_Query_V1.QueryServiceClient, return nil, r, nil } - return newTransaction(txID, s), r, nil + return newTx(txID, sessionID, client, settings.Trace), r, nil } diff --git a/internal/query/execute_query_test.go b/internal/query/execute_query_test.go index c3c14d754..89d2f2963 100644 --- a/internal/query/execute_query_test.go +++ b/internal/query/execute_query_test.go @@ -14,7 +14,6 @@ import ( "google.golang.org/grpc/metadata" grpcStatus "google.golang.org/grpc/status" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -355,12 +354,32 @@ func TestExecute(t *testing.T) { }, nil) stream.EXPECT().Recv().Return(nil, io.EOF) service := NewMockQueryServiceClient(ctrl) - service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) - tx, r, err := execute(ctx, newTestSession("123"), service, "", options.ExecuteSettings()) + service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Cond(func(x any) bool { + request, ok := x.(*Ydb_Query.ExecuteQueryRequest) + if !ok { + return false + } + if request.GetSessionId() != "123" { + return false + } + query := request.GetQueryContent() + if query == nil { + return false + } + if query.GetText() != "SELECT 1" { + return false + } + if query.GetSyntax() != Ydb_Query.Syntax_SYNTAX_YQL_V1 { + return false + } + + return true + })).Return(stream, nil) + tx, r, err := Execute[options.ExecuteOption](ctx, service, "123", "SELECT 1") require.NoError(t, err) defer r.Close(ctx) require.EqualValues(t, "456", tx.id) - require.EqualValues(t, "123", tx.s.id) + require.EqualValues(t, "123", tx.sessionID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") @@ -469,7 +488,7 @@ func TestExecute(t *testing.T) { service := NewMockQueryServiceClient(ctrl) service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(nil, grpcStatus.Error(grpcCodes.Unavailable, "")) t.Log("execute") - _, _, err := execute(ctx, newTestSession("123"), service, "", options.ExecuteSettings()) + _, _, err := Execute[options.ExecuteOption](ctx, service, "123", "") require.Error(t, err) require.True(t, xerrors.IsTransportError(err, grpcCodes.Unavailable)) }) @@ -573,11 +592,11 @@ func TestExecute(t *testing.T) { service := NewMockQueryServiceClient(ctrl) service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - tx, r, err := execute(ctx, newTestSession("123"), service, "", options.ExecuteSettings()) + tx, r, err := Execute[options.ExecuteOption](ctx, service, "123", "") require.NoError(t, err) defer r.Close(ctx) require.EqualValues(t, "456", tx.id) - require.EqualValues(t, "123", tx.s.id) + require.EqualValues(t, "123", tx.sessionID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") @@ -637,7 +656,7 @@ func TestExecute(t *testing.T) { service := NewMockQueryServiceClient(ctrl) service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - _, _, err := execute(ctx, newTestSession("123"), service, "", options.ExecuteSettings()) + _, _, err := Execute[options.ExecuteOption](ctx, service, "123", "") require.Error(t, err) require.True(t, xerrors.IsOperationError(err, Ydb.StatusIds_UNAVAILABLE)) }) @@ -713,11 +732,11 @@ func TestExecute(t *testing.T) { service := NewMockQueryServiceClient(ctrl) service.EXPECT().ExecuteQuery(gomock.Any(), gomock.Any()).Return(stream, nil) t.Log("execute") - tx, r, err := execute(ctx, newTestSession("123"), service, "", options.ExecuteSettings()) + tx, r, err := Execute[options.ExecuteOption](ctx, service, "123", "") require.NoError(t, err) defer r.Close(ctx) require.EqualValues(t, "456", tx.id) - require.EqualValues(t, "123", tx.s.id) + require.EqualValues(t, "123", tx.sessionID) require.EqualValues(t, -1, r.resultSetIndex) { t.Log("nextResultSet") @@ -757,12 +776,11 @@ func TestExecute(t *testing.T) { } func TestExecuteQueryRequest(t *testing.T) { - a := allocator.New() for _, tt := range []struct { - name string - opts []options.ExecuteOption - request *Ydb_Query.ExecuteQueryRequest - callOptions []grpc.CallOption + name string + opts []options.ExecuteOption + request *Ydb_Query.ExecuteQueryRequest + grpcCallOptions []grpc.CallOption }{ { name: "WithoutOptions", @@ -996,7 +1014,7 @@ func TestExecuteQueryRequest(t *testing.T) { StatsMode: Ydb_Query.StatsMode_STATS_MODE_NONE, ConcurrentResultSets: false, }, - callOptions: []grpc.CallOption{ + grpcCallOptions: []grpc.CallOption{ grpc.Header(&metadata.MD{ "ext-header": []string{"test"}, }), @@ -1004,9 +1022,1128 @@ func TestExecuteQueryRequest(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - request, callOptions := executeQueryRequest(a, tt.name, tt.name, options.ExecuteSettings(tt.opts...)) - require.Equal(t, request.String(), tt.request.String()) - require.Equal(t, tt.callOptions, callOptions) + ctx := xtest.Context(t) + ctrl := gomock.NewController(t) + stream := NewMockQueryService_ExecuteQueryClient(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "456", + }, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + }, nil) + client := NewMockQueryServiceClient(ctrl) + var args []any + if len(tt.grpcCallOptions) > 0 { + for _, grpcOpt := range tt.grpcCallOptions { + args = append(args, grpcOpt) + } + } + client.EXPECT().ExecuteQuery(gomock.Any(), gomock.Cond(func(x any) bool { + request, ok := x.(*Ydb_Query.ExecuteQueryRequest) + if !ok { + return false + } + + return tt.request.String() == request.String() + }), args...).Return(stream, nil) + tx, _, err := Execute(ctx, client, tt.name, tt.name, tt.opts...) + require.NoError(t, err) + require.Equal(t, tt.name, tx.sessionID) + require.Equal(t, "456", tx.id) }) } } + +func TestReadAll(t *testing.T) { + ctx := xtest.Context(t) + ctrl := gomock.NewController(t) + t.Run("NilResult", func(t *testing.T) { + resultSets, stats, err := ReadAll(ctx, nil) + require.ErrorIs(t, err, errNilResult) + require.Nil(t, resultSets) + require.Nil(t, stats) + }) + t.Run("io.EOF", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(nil, io.EOF) + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.ErrorIs(t, err, errNilResult) + require.Nil(t, resultSets) + require.Nil(t, stats) + }) + t.Run("EmptyResult", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: nil, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes() + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "123", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.NoError(t, err) + require.EqualValues(t, []*Ydb.ResultSet{ + {}, + }, resultSets) + require.Nil(t, stats) + }) + t.Run("SingleResultSet", func(t *testing.T) { + t.Run("SinglePart", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes() + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "123", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.NoError(t, err) + require.EqualValues(t, []*Ydb.ResultSet{ + { + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + }, resultSets) + require.Nil(t, stats) + }) + t.Run("TwoParts", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes() + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "123", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.NoError(t, err) + require.EqualValues(t, []*Ydb.ResultSet{ + { + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + }, resultSets) + require.Nil(t, stats) + }) + }) + t.Run("TwoResultSets", func(t *testing.T) { + t.Run("SinglePart", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 1, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "c", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "d", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes() + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "123", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.NoError(t, err) + require.EqualValues(t, []*Ydb.ResultSet{ + { + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + { + Columns: []*Ydb.Column{ + { + Name: "c", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "d", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + }, resultSets) + require.Nil(t, stats) + }) + t.Run("TwoParts", func(t *testing.T) { + stream := NewMockResultStream(ctrl) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 0, + ResultSet: &Ydb.ResultSet{ + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 1, + ResultSet: &Ydb.ResultSet{ + Columns: []*Ydb.Column{ + { + Name: "c", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "d", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(&Ydb_Query.ExecuteQueryResponsePart{ + Status: Ydb.StatusIds_SUCCESS, + ResultSetIndex: 1, + ResultSet: &Ydb.ResultSet{ + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + ExecStats: nil, + TxMeta: &Ydb_Query.TransactionMeta{ + Id: "123", + }, + }, nil) + stream.EXPECT().Recv().Return(nil, io.EOF).AnyTimes() + r, txID, err := newResult(ctx, stream, nil, nil) + require.NoError(t, err) + require.Equal(t, "123", txID) + resultSets, stats, err := ReadAll(ctx, r) + require.NoError(t, err) + require.EqualValues(t, []*Ydb.ResultSet{ + { + Columns: []*Ydb.Column{ + { + Name: "a", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "b", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + { + Columns: []*Ydb.Column{ + { + Name: "c", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UINT64, + }, + }, + }, + { + Name: "d", + Type: &Ydb.Type{ + Type: &Ydb.Type_TypeId{ + TypeId: Ydb.Type_UTF8, + }, + }, + }, + }, + Rows: []*Ydb.Value{ + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 1, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "1", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 2, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "2", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 3, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "3", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 4, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "4", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 5, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "5", + }, + }}, + }, + { + Items: []*Ydb.Value{{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: 6, + }, + }, { + Value: &Ydb.Value_TextValue{ + TextValue: "6", + }, + }}, + }, + }, + }, + }, resultSets) + require.Nil(t, stats) + }) + }) +} diff --git a/internal/query/options/execute.go b/internal/query/options/execute.go index 6a306c26d..b92dfb031 100644 --- a/internal/query/options/execute.go +++ b/internal/query/options/execute.go @@ -4,37 +4,43 @@ import ( "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" "google.golang.org/grpc" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/tx" + "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) type ( - Syntax Ydb_Query.Syntax - ExecMode Ydb_Query.ExecMode - StatsMode Ydb_Query.StatsMode - CallOptions []grpc.CallOption - commonExecuteSettings struct { - syntax Syntax - params params.Parameters - execMode ExecMode - statsMode StatsMode - callOptions []grpc.CallOption + Syntax Ydb_Query.Syntax + ExecMode Ydb_Query.ExecMode + StatsMode Ydb_Query.StatsMode + GrpcOpts []grpc.CallOption + executeSettings struct { + Syntax Syntax + Params params.Parameters + ExecMode ExecMode + StatsMode StatsMode + GrpcCallOptions []grpc.CallOption + Trace *trace.Query + Allocator *allocator.Allocator } Execute struct { - commonExecuteSettings + executeSettings - txControl *tx.Control + TxControl *tx.Control } ExecuteOption interface { - applyExecuteOption(s *Execute) + ApplyExecuteOption(s *Execute) } txExecuteSettings struct { - ExecuteSettings *Execute + ExecuteOptions []ExecuteOption - commitTx bool + CommitTx bool } TxExecuteOption interface { - applyTxExecuteOption(s *txExecuteSettings) + ExecuteOption + + ApplyTxExecuteOption(s *txExecuteSettings) } txCommitOption struct{} ParametersOption params.Parameters @@ -43,20 +49,23 @@ type ( } ) -func (opt TxControlOption) applyExecuteOption(s *Execute) { - s.txControl = opt.txControl +func (t txCommitOption) ApplyExecuteOption(s *Execute) { +} + +func (opt TxControlOption) ApplyExecuteOption(s *Execute) { + s.TxControl = opt.txControl } -func (t txCommitOption) applyTxExecuteOption(s *txExecuteSettings) { - s.commitTx = true +func (t txCommitOption) ApplyTxExecuteOption(s *txExecuteSettings) { + s.CommitTx = true } -func (syntax Syntax) applyTxExecuteOption(s *txExecuteSettings) { - syntax.applyExecuteOption(s.ExecuteSettings) +func (syntax Syntax) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, syntax) } -func (syntax Syntax) applyExecuteOption(s *Execute) { - s.syntax = syntax +func (syntax Syntax) ApplyExecuteOption(s *Execute) { + s.Syntax = syntax } const ( @@ -64,36 +73,36 @@ const ( SyntaxPostgreSQL = Syntax(Ydb_Query.Syntax_SYNTAX_PG) ) -func (params ParametersOption) applyTxExecuteOption(s *txExecuteSettings) { - params.applyExecuteOption(s.ExecuteSettings) +func (params ParametersOption) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, params) } -func (params ParametersOption) applyExecuteOption(s *Execute) { - s.params = append(s.params, params...) +func (params ParametersOption) ApplyExecuteOption(s *Execute) { + s.Params = append(s.Params, params...) } -func (opts CallOptions) applyExecuteOption(s *Execute) { - s.callOptions = append(s.callOptions, opts...) +func (opts GrpcOpts) ApplyExecuteOption(s *Execute) { + s.GrpcCallOptions = append(s.GrpcCallOptions, opts...) } -func (opts CallOptions) applyTxExecuteOption(s *txExecuteSettings) { - opts.applyExecuteOption(s.ExecuteSettings) +func (opts GrpcOpts) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, opts) } -func (mode StatsMode) applyTxExecuteOption(s *txExecuteSettings) { - mode.applyExecuteOption(s.ExecuteSettings) +func (mode StatsMode) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, mode) } -func (mode StatsMode) applyExecuteOption(s *Execute) { - s.statsMode = mode +func (mode StatsMode) ApplyExecuteOption(s *Execute) { + s.StatsMode = mode } -func (mode ExecMode) applyTxExecuteOption(s *txExecuteSettings) { - mode.applyExecuteOption(s.ExecuteSettings) +func (mode ExecMode) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, mode) } -func (mode ExecMode) applyExecuteOption(s *Execute) { - s.execMode = mode +func (mode ExecMode) ApplyExecuteOption(s *Execute) { + s.ExecMode = mode } const ( @@ -110,68 +119,37 @@ const ( StatsModeProfile = StatsMode(Ydb_Query.StatsMode_STATS_MODE_PROFILE) ) -func defaultCommonExecuteSettings() commonExecuteSettings { - return commonExecuteSettings{ - syntax: SyntaxYQL, - execMode: ExecModeExecute, - statsMode: StatsModeNone, +func defaultExecuteSettings() executeSettings { + return executeSettings{ + Syntax: SyntaxYQL, + ExecMode: ExecModeExecute, + StatsMode: StatsModeNone, + Trace: &trace.Query{}, } } -func ExecuteSettings(opts ...ExecuteOption) (settings *Execute) { +func ExecuteSettings[T ExecuteOption](opts ...T) (settings *Execute) { settings = &Execute{ - commonExecuteSettings: defaultCommonExecuteSettings(), + executeSettings: defaultExecuteSettings(), } - settings.commonExecuteSettings = defaultCommonExecuteSettings() - settings.txControl = tx.DefaultTxControl() + settings.executeSettings = defaultExecuteSettings() + settings.TxControl = tx.DefaultTxControl() for _, opt := range opts { - if opt != nil { - opt.applyExecuteOption(settings) - } + opt.ApplyExecuteOption(settings) } return settings } -func (s *Execute) TxControl() *tx.Control { - return s.txControl -} - -func (s *Execute) SetTxControl(ctrl *tx.Control) { - s.txControl = ctrl -} - -func (s *commonExecuteSettings) CallOptions() []grpc.CallOption { - return s.callOptions -} - -func (s *commonExecuteSettings) Syntax() Syntax { - return s.syntax -} - -func (s *commonExecuteSettings) ExecMode() ExecMode { - return s.execMode -} - -func (s *commonExecuteSettings) StatsMode() StatsMode { - return s.statsMode -} - -func (s *commonExecuteSettings) Params() *params.Parameters { - if len(s.params) == 0 { - return nil - } - - return &s.params -} - func TxExecuteSettings(id string, opts ...TxExecuteOption) (settings *txExecuteSettings) { settings = &txExecuteSettings{ - ExecuteSettings: ExecuteSettings(WithTxControl(tx.NewControl(tx.WithTxID(id)))), + ExecuteOptions: []ExecuteOption{ + WithTxControl(tx.NewControl(tx.WithTxID(id))), + }, } for _, opt := range opts { if opt != nil { - opt.applyTxExecuteOption(settings) + opt.ApplyTxExecuteOption(settings) } } @@ -191,6 +169,8 @@ var ( _ TxExecuteOption = StatsMode(0) _ TxExecuteOption = txCommitOption{} _ ExecuteOption = TxControlOption{} + _ TxExecuteOption = traceOption{} + _ ExecuteOption = traceOption{} ) func WithCommit() txCommitOption { @@ -215,7 +195,7 @@ func WithStatsMode(mode StatsMode) StatsMode { return mode } -func WithCallOptions(opts ...grpc.CallOption) CallOptions { +func WithCallOptions(opts ...grpc.CallOption) GrpcOpts { return opts } diff --git a/internal/query/options/retry.go b/internal/query/options/retry.go index b604152e3..21bd9723f 100644 --- a/internal/query/options/retry.go +++ b/internal/query/options/retry.go @@ -36,11 +36,8 @@ type ( txSettings tx.Settings } - idempotentOption struct{} - labelOption string - traceOption struct { - t *trace.Query - } + idempotentOption struct{} + labelOption string doTxSettingsOption struct { txSettings tx.Settings } @@ -70,14 +67,6 @@ func (idempotentOption) applyDoOption(s *doSettings) { s.retryOpts = append(s.retryOpts, retry.WithIdempotent(true)) } -func (opt traceOption) applyDoOption(s *doSettings) { - s.trace = s.trace.Compose(opt.t) -} - -func (opt traceOption) applyDoTxOption(s *doTxSettings) { - s.doOpts = append(s.doOpts, opt) -} - func (opt labelOption) applyDoOption(s *doSettings) { s.retryOpts = append(s.retryOpts, retry.WithLabel(string(opt))) } diff --git a/internal/query/options/trace.go b/internal/query/options/trace.go new file mode 100644 index 000000000..3dfaef05d --- /dev/null +++ b/internal/query/options/trace.go @@ -0,0 +1,23 @@ +package options + +import "github.com/ydb-platform/ydb-go-sdk/v3/trace" + +type traceOption struct { + t *trace.Query +} + +func (opt traceOption) ApplyTxExecuteOption(s *txExecuteSettings) { + s.ExecuteOptions = append(s.ExecuteOptions, opt) +} + +func (opt traceOption) ApplyExecuteOption(s *Execute) { + s.Trace = s.Trace.Compose(opt.t) +} + +func (opt traceOption) applyDoOption(s *doSettings) { + s.trace = s.trace.Compose(opt.t) +} + +func (opt traceOption) applyDoTxOption(s *doTxSettings) { + s.doOpts = append(s.doOpts, opt) +} diff --git a/internal/query/result.go b/internal/query/result.go index 78478610d..32e8b767f 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -5,7 +5,6 @@ import ( "fmt" "io" - "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" @@ -15,21 +14,28 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) +//go:generate mockgen -destination result_stream_mock_test.go -package query -write_package_comment=false . ResultStream + var _ query.Result = (*result)(nil) -type result struct { - stream Ydb_Query_V1.QueryService_ExecuteQueryClient - closeOnce func(ctx context.Context) error - lastPart *Ydb_Query.ExecuteQueryResponsePart - resultSetIndex int64 - errs []error - closed chan struct{} - trace *trace.Query -} +type ( + ResultStream interface { + Recv() (*Ydb_Query.ExecuteQueryResponsePart, error) + } + result struct { + stream ResultStream + closeOnce func(ctx context.Context) error + lastPart *Ydb_Query.ExecuteQueryResponsePart + resultSetIndex int64 + errs []error + closed chan struct{} + trace *trace.Query + } +) func newResult( ctx context.Context, - stream Ydb_Query_V1.QueryService_ExecuteQueryClient, + stream ResultStream, t *trace.Query, closeResult context.CancelFunc, ) (_ *result, txID string, err error) { @@ -53,6 +59,10 @@ func newResult( default: part, err := nextPart(ctx, stream, t) if err != nil { + if xerrors.Is(err, io.EOF) { + return nil, txID, nil + } + return nil, txID, xerrors.WithStackTrace(err) } var ( @@ -81,7 +91,7 @@ func newResult( func nextPart( ctx context.Context, - stream Ydb_Query_V1.QueryService_ExecuteQueryClient, + stream ResultStream, t *trace.Query, ) (_ *Ydb_Query.ExecuteQueryResponsePart, finalErr error) { if t == nil { @@ -130,7 +140,7 @@ func (r *result) nextResultSet(ctx context.Context) (_ *resultSet, err error) { case <-ctx.Done(): return nil, xerrors.WithStackTrace(ctx.Err()) default: - if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex { //nolint:nestif + if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex { r.resultSetIndex = resultSetIndex return newResultSet(func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) { @@ -147,10 +157,6 @@ func (r *result) nextResultSet(ctx context.Context) (_ *resultSet, err error) { default: part, err := nextPart(ctx, r.stream, r.trace) if err != nil { - if xerrors.Is(err, io.EOF) { - _ = r.closeOnce(ctx) - } - return nil, xerrors.WithStackTrace(err) } r.lastPart = part diff --git a/internal/query/result_set.go b/internal/query/result_set.go index 8d2ae8b71..cf2299cc9 100644 --- a/internal/query/result_set.go +++ b/internal/query/result_set.go @@ -7,6 +7,7 @@ import ( "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -46,6 +47,14 @@ func newResultSet( } } +func (rs *resultSet) stats() *Ydb_TableStats.QueryStats { + if rs == nil || rs.currentPart == nil { + return nil + } + + return rs.currentPart.GetExecStats() +} + func (rs *resultSet) nextRow(ctx context.Context) (*row, error) { rs.rowIndex++ for { diff --git a/internal/query/result_stream_mock_test.go b/internal/query/result_stream_mock_test.go new file mode 100644 index 000000000..4a31558b8 --- /dev/null +++ b/internal/query/result_stream_mock_test.go @@ -0,0 +1,52 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ydb-platform/ydb-go-sdk/v3/internal/query (interfaces: ResultStream) +// +// Generated by this command: +// +// mockgen -destination result_stream_mock_test.go -package query -write_package_comment=false . ResultStream +package query + +import ( + reflect "reflect" + + Ydb_Query "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" + gomock "go.uber.org/mock/gomock" +) + +// MockResultStream is a mock of ResultStream interface. +type MockResultStream struct { + ctrl *gomock.Controller + recorder *MockResultStreamMockRecorder +} + +// MockResultStreamMockRecorder is the mock recorder for MockResultStream. +type MockResultStreamMockRecorder struct { + mock *MockResultStream +} + +// NewMockResultStream creates a new mock instance. +func NewMockResultStream(ctrl *gomock.Controller) *MockResultStream { + mock := &MockResultStream{ctrl: ctrl} + mock.recorder = &MockResultStreamMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResultStream) EXPECT() *MockResultStreamMockRecorder { + return m.recorder +} + +// Recv mocks base method. +func (m *MockResultStream) Recv() (*Ydb_Query.ExecuteQueryResponsePart, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*Ydb_Query.ExecuteQueryResponsePart) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockResultStreamMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockResultStream)(nil).Recv)) +} diff --git a/internal/query/row.go b/internal/query/row.go index 476b6aa14..1a0ce2c5c 100644 --- a/internal/query/row.go +++ b/internal/query/row.go @@ -16,6 +16,7 @@ var _ query.Row = (*row)(nil) type row struct { ctx context.Context trace *trace.Query + v *Ydb.Value indexedScanner scanner.IndexedScanner namedScanner scanner.NamedScanner @@ -28,6 +29,7 @@ func newRow(ctx context.Context, columns []*Ydb.Column, v *Ydb.Value, t *trace.Q return &row{ ctx: ctx, trace: t, + v: v, indexedScanner: scanner.Indexed(data), namedScanner: scanner.Named(data), structScanner: scanner.Struct(data), diff --git a/internal/query/session.go b/internal/query/session.go index 708b36df6..c089f1669 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -26,7 +26,7 @@ type ( cfg *config.Config id string nodeID int64 - grpcClient Ydb_Query_V1.QueryServiceClient + client Ydb_Query_V1.QueryServiceClient statusCode statusCode closeOnce func(ctx context.Context) error checks []func(s *Session) bool @@ -45,7 +45,7 @@ func createSession( ) (s *Session, finalErr error) { s = &Session{ cfg: cfg, - grpcClient: client, + client: client, statusCode: statusUnknown, checks: []func(*Session) bool{ func(s *Session) bool { @@ -68,7 +68,7 @@ func createSession( opt(s) } - onDone := trace.QueryOnSessionCreate(s.cfg.Trace(), &ctx, + onDone := trace.QueryOnSessionCreate(s.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.createSession"), ) defer func() { @@ -99,8 +99,16 @@ func createSession( return s, nil } +func (s *Session) Trace() *trace.Query { + if s == nil || s.cfg == nil { + return nil + } + + return s.cfg.Trace() +} + func (s *Session) attach(ctx context.Context) (finalErr error) { - onDone := trace.QueryOnSessionAttach(s.cfg.Trace(), &ctx, + onDone := trace.QueryOnSessionAttach(s.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.(*Session).attach"), s) defer func() { onDone(finalErr) @@ -108,7 +116,7 @@ func (s *Session) attach(ctx context.Context) (finalErr error) { attachCtx, cancelAttach := xcontext.WithCancel(xcontext.ValueOnly(ctx)) - attach, err := s.grpcClient.AttachSession(attachCtx, &Ydb_Query.AttachSessionRequest{ + attach, err := s.client.AttachSession(attachCtx, &Ydb_Query.AttachSessionRequest{ SessionId: s.id, }) if err != nil { @@ -136,7 +144,7 @@ func (s *Session) attach(ctx context.Context) (finalErr error) { } defer cancel() - if err = deleteSession(ctx, s.grpcClient, s.id); err != nil { + if err = deleteSession(ctx, s.client, s.id); err != nil { return xerrors.WithStackTrace(err) } @@ -192,7 +200,7 @@ func (s *Session) IsAlive() bool { } func (s *Session) Close(ctx context.Context) (err error) { - onDone := trace.QueryOnSessionDelete(s.cfg.Trace(), &ctx, + onDone := trace.QueryOnSessionDelete(s.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.(*Session).Close"), s) defer func() { onDone(err) @@ -223,7 +231,7 @@ func begin( return nil, xerrors.WithStackTrace(err) } - return newTransaction(response.GetTxMeta().GetId(), s), nil + return newTx(response.GetTxMeta().GetId(), s.id, s.client, s.Trace()), nil } func (s *Session) Begin( @@ -234,17 +242,16 @@ func (s *Session) Begin( ) { var tx *transaction - onDone := trace.QueryOnSessionBegin(s.cfg.Trace(), &ctx, + onDone := trace.QueryOnSessionBegin(s.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.(*Session).Begin"), s) defer func() { onDone(err, tx) }() - tx, err = begin(ctx, s.grpcClient, s, txSettings) + tx, err = begin(ctx, s.client, s, txSettings) if err != nil { return nil, xerrors.WithStackTrace(err) } - tx.s = s return tx, nil } @@ -272,13 +279,18 @@ func (s *Session) Status() string { func (s *Session) Execute( ctx context.Context, q string, opts ...options.ExecuteOption, ) (_ query.Transaction, _ query.Result, err error) { - onDone := trace.QueryOnSessionExecute(s.cfg.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.(*Session).Execute"), s, q) + onDone := trace.QueryOnSessionExecute(s.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.(*Session).Execute"), + s, q, + ) defer func() { onDone(err) }() - tx, r, err := execute(ctx, s, s.grpcClient, q, options.ExecuteSettings(opts...)) + a := allocator.New() + defer a.Free() + + tx, r, err := Execute(ctx, s.client, s.ID(), q, append(opts, WithAllocator(a), options.WithTrace(s.Trace()))...) if err != nil { return nil, nil, xerrors.WithStackTrace(err) } diff --git a/internal/query/transaction.go b/internal/query/transaction.go index fbfcd9151..838d72291 100644 --- a/internal/query/transaction.go +++ b/internal/query/transaction.go @@ -6,6 +6,7 @@ import ( "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -16,14 +17,19 @@ import ( var _ query.Transaction = (*transaction)(nil) type transaction struct { - id string - s *Session + id string + sessionID string + + client Ydb_Query_V1.QueryServiceClient + trace *trace.Query } -func newTransaction(id string, s *Session) *transaction { +func newTx(txID, sessionID string, client Ydb_Query_V1.QueryServiceClient, trace *trace.Query) *transaction { return &transaction{ - id: id, - s: s, + id: txID, + sessionID: sessionID, + client: client, + trace: trace, } } @@ -34,13 +40,18 @@ func (tx transaction) ID() string { func (tx transaction) Execute(ctx context.Context, q string, opts ...options.TxExecuteOption) ( r query.Result, finalErr error, ) { - onDone := trace.QueryOnTxExecute(tx.s.cfg.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.transaction.Execute"), tx.s, tx, q) + onDone := trace.QueryOnTxExecute(tx.trace, &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/query.transaction.Execute"), tx, q) defer func() { onDone(finalErr) }() - _, res, err := execute(ctx, tx.s, tx.s.grpcClient, q, options.TxExecuteSettings(tx.id, opts...).ExecuteSettings) + a := allocator.New() + defer a.Free() + + settings := options.TxExecuteSettings(tx.id, opts...) + + _, res, err := Execute(ctx, tx.client, tx.sessionID, q, settings.ExecuteOptions...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -61,7 +72,7 @@ func commitTx(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi } func (tx transaction) CommitTx(ctx context.Context) (err error) { - return commitTx(ctx, tx.s.grpcClient, tx.s.id, tx.id) + return commitTx(ctx, tx.client, tx.sessionID, tx.id) } func rollback(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessionID, txID string) error { @@ -77,5 +88,5 @@ func rollback(ctx context.Context, client Ydb_Query_V1.QueryServiceClient, sessi } func (tx transaction) Rollback(ctx context.Context) (err error) { - return rollback(ctx, tx.s.grpcClient, tx.s.id, tx.id) + return rollback(ctx, tx.client, tx.sessionID, tx.id) } diff --git a/internal/query/transaction_test.go b/internal/query/transaction_test.go index 83ecdfc6c..adac6977c 100644 --- a/internal/query/transaction_test.go +++ b/internal/query/transaction_test.go @@ -132,14 +132,13 @@ func (s testExecuteSettings) CallOptions() []grpc.CallOption { return s.callOptions } -var _ executeConfig = testExecuteSettings{} - func TestTxExecuteSettings(t *testing.T) { for _, tt := range []struct { name string txID string txOpts []options.TxExecuteOption - settings executeConfig + settings testExecuteSettings + commitTx bool }{ { name: "WithTxID", @@ -218,16 +217,32 @@ func TestTxExecuteSettings(t *testing.T) { params: params.Builder{}.Param("$a").Text("A").Build(), }, }, + { + name: "WithCommit", + txID: "test", + txOpts: []options.TxExecuteOption{ + options.WithCommit(), + }, + settings: testExecuteSettings{ + execMode: options.ExecModeExecute, + statsMode: options.StatsModeNone, + txControl: query.TxControl(query.WithTxID("test")), + syntax: options.SyntaxYQL, + }, + commitTx: true, + }, } { t.Run(tt.name, func(t *testing.T) { a := allocator.New() - settings := options.TxExecuteSettings(tt.txID, tt.txOpts...).ExecuteSettings - require.Equal(t, tt.settings.Syntax(), settings.Syntax()) - require.Equal(t, tt.settings.ExecMode(), settings.ExecMode()) - require.Equal(t, tt.settings.StatsMode(), settings.StatsMode()) - require.Equal(t, tt.settings.TxControl().ToYDB(a).String(), settings.TxControl().ToYDB(a).String()) - require.Equal(t, tt.settings.Params().ToYDB(a), settings.Params().ToYDB(a)) - require.Equal(t, tt.settings.CallOptions(), settings.CallOptions()) + settings := options.TxExecuteSettings(tt.txID, tt.txOpts...) + executeSettings := options.ExecuteSettings(settings.ExecuteOptions...) + require.Equal(t, tt.settings.Syntax(), executeSettings.Syntax) + require.Equal(t, tt.settings.ExecMode(), executeSettings.ExecMode) + require.Equal(t, tt.settings.StatsMode(), executeSettings.StatsMode) + require.Equal(t, tt.settings.TxControl().ToYDB(a).String(), executeSettings.TxControl.ToYDB(a).String()) + require.Equal(t, tt.settings.Params().ToYDB(a), executeSettings.Params.ToYDB(a)) + require.Equal(t, tt.settings.CallOptions(), executeSettings.GrpcCallOptions) + require.Equal(t, tt.commitTx, settings.CommitTx) }) } } diff --git a/internal/query/tx/control.go b/internal/query/tx/control.go index a5be2fb21..922d795b7 100644 --- a/internal/query/tx/control.go +++ b/internal/query/tx/control.go @@ -21,8 +21,8 @@ type ( applyTxControlOption(txControl *Control) } Control struct { - selector Selector - commit bool + Selector Selector + Commit bool } Identifier interface { ID() string @@ -35,8 +35,8 @@ func (ctrl *Control) ToYDB(a *allocator.Allocator) *Ydb_Query.TransactionControl } txControl := a.QueryTransactionControl() - ctrl.selector.applyTxSelector(a, txControl) - txControl.CommitTx = ctrl.commit + ctrl.Selector.applyTxSelector(a, txControl) + txControl.CommitTx = ctrl.Commit return txControl } @@ -49,7 +49,7 @@ var ( type beginTxOptions []Option func (opts beginTxOptions) applyTxControlOption(txControl *Control) { - txControl.selector = opts + txControl.Selector = opts } func (opts beginTxOptions) applyTxSelector(a *allocator.Allocator, txControl *Ydb_Query.TransactionControl) { @@ -76,7 +76,7 @@ var ( type txIDTxControlOption string func (id txIDTxControlOption) applyTxControlOption(txControl *Control) { - txControl.selector = id + txControl.Selector = id } func (id txIDTxControlOption) applyTxSelector(a *allocator.Allocator, txControl *Ydb_Query.TransactionControl) { @@ -96,7 +96,7 @@ func WithTxID(txID string) txIDTxControlOption { type commitTxOption struct{} func (c commitTxOption) applyTxControlOption(txControl *Control) { - txControl.commit = true + txControl.Commit = true } // CommitTx returns commit transaction control option @@ -107,8 +107,8 @@ func CommitTx() ControlOption { // NewControl makes transaction control from given options func NewControl(opts ...ControlOption) *Control { txControl := &Control{ - selector: BeginTx(WithSerializableReadWrite()), - commit: false, + Selector: BeginTx(WithSerializableReadWrite()), + Commit: false, } for _, opt := range opts { if opt != nil { diff --git a/internal/table/client.go b/internal/table/client.go index 7d063346b..5d9e4bad9 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -26,7 +26,7 @@ import ( type sessionBuilder func(ctx context.Context) (*session, error) type nodeChecker interface { - HasNode(id uint32) bool + HasNode(id int64) bool } type balancer interface { diff --git a/internal/table/client_test.go b/internal/table/client_test.go index 82b7eb6af..e1bf06882 100644 --- a/internal/table/client_test.go +++ b/internal/table/client_test.go @@ -937,7 +937,7 @@ func TestDeadlockOnUpdateNodes(t *testing.T) { ctx, cancel := xcontext.WithTimeout(context.Background(), 1*time.Second) defer cancel() var ( - nodes = make([]uint32, 0, 3) + nodes = make([]int64, 0, 3) nodeIDCounter = uint32(0) ) balancer := testutil.NewBalancer(testutil.WithInvokeHandlers(testutil.InvokeHandlers{ @@ -980,7 +980,7 @@ func TestDeadlockOnInternalPoolGCTick(t *testing.T) { ctx, cancel := xcontext.WithTimeout(context.Background(), 1*time.Second) defer cancel() var ( - nodes = make([]uint32, 0, 3) + nodes = make([]int64, 0, 3) nodeIDCounter = uint32(0) ) balancer := testutil.NewBalancer(testutil.WithInvokeHandlers(testutil.InvokeHandlers{ diff --git a/internal/table/data_query.go b/internal/table/data_query.go index f918b7192..9b87b435e 100644 --- a/internal/table/data_query.go +++ b/internal/table/data_query.go @@ -7,7 +7,7 @@ import ( ) type ( - query interface { + queryRenameMe interface { String() string ID() string YQL() string @@ -59,11 +59,11 @@ func (q preparedDataQuery) toYDB(a *allocator.Allocator) *Ydb_Table.Query { return query } -func queryFromText(s string) query { +func queryFromText(s string) queryRenameMe { return textDataQuery(s) } -func queryPrepared(id, query string) query { +func queryPrepared(id, query string) queryRenameMe { return preparedDataQuery{ id: id, query: query, diff --git a/internal/table/session.go b/internal/table/session.go index 98e8ce4db..93b349d35 100644 --- a/internal/table/session.go +++ b/internal/table/session.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" "github.com/ydb-platform/ydb-go-genproto/Ydb_Table_V1" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Table" @@ -23,6 +24,8 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/meta" "github.com/ydb-platform/ydb-go-sdk/v3/internal/operation" "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" + options2 "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/scanner" @@ -45,15 +48,16 @@ import ( // Note that after session is no longer needed it should be destroyed by // Close() call. type session struct { - onClose []func(s *session) - id string - tableService Ydb_Table_V1.TableServiceClient - status table.SessionStatus - config *config.Config - lastUsage atomic.Int64 - statusMtx sync.RWMutex - closeOnce sync.Once - nodeID atomic.Uint32 + onClose []func(s *session) + id string + client Ydb_Table_V1.TableServiceClient + queryClient Ydb_Query_V1.QueryServiceClient + status table.SessionStatus + config *config.Config + lastUsage atomic.Int64 + statusMtx sync.RWMutex + closeOnce sync.Once + nodeID uint32 } func (s *session) LastUsage() time.Time { @@ -77,16 +81,8 @@ func (s *session) NodeID() uint32 { if s == nil { return 0 } - if id := s.nodeID.Load(); id != 0 { - return id - } - id, err := nodeID(s.id) - if err != nil { - return 0 - } - s.nodeID.Store(id) - return id + return s.nodeID } func (s *session) Status() table.SessionStatus { @@ -145,14 +141,31 @@ func newSession(ctx context.Context, cc grpc.ClientConnInterface, config *config return nil, xerrors.WithStackTrace(err) } + nodeID, err := nodeID(s.id) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + s = &session{ id: result.GetSessionId(), + nodeID: nodeID, config: config, status: table.SessionReady, } s.lastUsage.Store(time.Now().Unix()) - s.tableService = Ydb_Table_V1.NewTableServiceClient( + s.client = Ydb_Table_V1.NewTableServiceClient( + conn.WithBeforeFunc( + conn.WithContextModifier(cc, func(ctx context.Context) context.Context { + return meta.WithTrailerCallback(balancerContext.WithEndpoint(ctx, s), s.checkCloseHint) + }), + func() { + s.lastUsage.Store(time.Now().Unix()) + }, + ), + ) + + s.queryClient = Ydb_Query_V1.NewQueryServiceClient( conn.WithBeforeFunc( conn.WithContextModifier(cc, func(ctx context.Context) context.Context { return meta.WithTrailerCallback(balancerContext.WithEndpoint(ctx, s), s.checkCloseHint) @@ -190,7 +203,7 @@ func (s *session) Close(ctx context.Context) (err error) { }() if time.Since(s.LastUsage()) < s.config.IdleThreshold() { - _, err = s.tableService.DeleteSession(ctx, + _, err = s.client.DeleteSession(ctx, &Ydb_Table.DeleteSessionRequest{ SessionId: s.id, OperationParams: operation.Params(ctx, @@ -241,7 +254,7 @@ func (s *session) KeepAlive(ctx context.Context) (err error) { onDone(err) }() - resp, err := s.tableService.KeepAlive(ctx, + resp, err := s.client.KeepAlive(ctx, &Ydb_Table.KeepAliveRequest{ SessionId: s.id, OperationParams: operation.Params( @@ -296,7 +309,7 @@ func (s *session) CreateTable( opt.ApplyCreateTableOption((*options.CreateTableDesc)(&request), a) } } - _, err = s.tableService.CreateTable(ctx, &request) + _, err = s.client.CreateTable(ctx, &request) if err != nil { return xerrors.WithStackTrace(err) } @@ -329,7 +342,7 @@ func (s *session) DescribeTable( opt((*options.DescribeTableDesc)(&request)) } } - response, err = s.tableService.DescribeTable(ctx, &request) + response, err = s.client.DescribeTable(ctx, &request) if err != nil { return desc, xerrors.WithStackTrace(err) } @@ -478,7 +491,7 @@ func (s *session) DropTable( opt.ApplyDropTableOption((*options.DropTableDesc)(&request)) } } - _, err = s.tableService.DropTable(ctx, &request) + _, err = s.client.DropTable(ctx, &request) return xerrors.WithStackTrace(err) } @@ -517,7 +530,7 @@ func (s *session) AlterTable( opt.ApplyAlterTableOption((*options.AlterTableDesc)(&request), a) } } - _, err = s.tableService.AlterTable(ctx, &request) + _, err = s.client.AlterTable(ctx, &request) return xerrors.WithStackTrace(err) } @@ -544,7 +557,7 @@ func (s *session) CopyTable( opt((*options.CopyTableDesc)(&request)) } } - _, err = s.tableService.CopyTable(ctx, &request) + _, err = s.client.CopyTable(ctx, &request) if err != nil { return xerrors.WithStackTrace(err) } @@ -594,7 +607,7 @@ func (s *session) CopyTables( ctx context.Context, opts ...options.CopyTablesOption, ) (err error) { - err = copyTables(ctx, s.id, s.config.OperationTimeout(), s.config.OperationCancelAfter(), s.tableService, opts...) + err = copyTables(ctx, s.id, s.config.OperationTimeout(), s.config.OperationCancelAfter(), s.client, opts...) if err != nil { return xerrors.WithStackTrace(err) } @@ -644,7 +657,7 @@ func (s *session) RenameTables( ctx context.Context, opts ...options.RenameTablesOption, ) (err error) { - err = renameTables(ctx, s.id, s.config.OperationTimeout(), s.config.OperationCancelAfter(), s.tableService, opts...) + err = renameTables(ctx, s.id, s.config.OperationTimeout(), s.config.OperationCancelAfter(), s.client, opts...) if err != nil { return xerrors.WithStackTrace(err) } @@ -677,7 +690,7 @@ func (s *session) Explain( } }() - response, err = s.tableService.ExplainDataQuery(ctx, + response, err = s.client.ExplainDataQuery(ctx, &Ydb_Table.ExplainDataQueryRequest{ SessionId: s.id, YqlText: query, @@ -726,7 +739,7 @@ func (s *session) Prepare(ctx context.Context, queryText string) (_ table.Statem } }() - response, err = s.tableService.PrepareDataQuery(ctx, + response, err = s.client.PrepareDataQuery(ctx, &Ydb_Table.PrepareDataQueryRequest{ SessionId: s.id, YqlText: queryText, @@ -760,15 +773,15 @@ func (s *session) Prepare(ctx context.Context, queryText string) (_ table.Statem func (s *session) Execute( ctx context.Context, txControl *table.TransactionControl, - query string, + q string, parameters *params.Parameters, opts ...options.ExecuteDataQueryOption, ) ( - txr table.Transaction, r result.Result, err error, + _ table.Transaction, r result.Result, finalErr error, ) { var ( a = allocator.New() - q = queryFromText(query) + qq = queryFromText(q) request = options.ExecuteDataQueryDesc{ ExecuteDataQueryRequest: a.TableExecuteDataQueryRequest(), IgnoreTruncated: s.config.IgnoreTruncated(), @@ -780,7 +793,7 @@ func (s *session) Execute( request.SessionId = s.id request.TxControl = txControl.Desc() request.Parameters = parameters.ToYDB(a) - request.Query = q.toYDB(a) + request.Query = qq.toYDB(a) request.QueryCachePolicy = a.TableQueryCachePolicy() request.QueryCachePolicy.KeepInCache = len(request.Parameters) > 0 request.OperationParams = operation.Params(ctx, @@ -795,22 +808,58 @@ func (s *session) Execute( } } - onDone := trace.TableOnSessionQueryExecute( - s.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/table.(*session).Execute"), - s, q, parameters, - request.QueryCachePolicy.GetKeepInCache(), + var ( + onDone = trace.TableOnSessionQueryExecute( + s.config.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/table.(*session).Execute"), + s, qq, parameters, + request.QueryCachePolicy.GetKeepInCache(), + ) + txr *transaction ) defer func() { - onDone(txr, false, r, err) + onDone(txr, false, r, finalErr) }() - result, err := s.executeDataQuery(ctx, a, request.ExecuteDataQueryRequest, callOptions...) + if !request.WithQueryService { + result, err := s.executeDataQuery(ctx, a, request.ExecuteDataQueryRequest, callOptions...) + if err != nil { + return nil, nil, xerrors.WithStackTrace(err) + } + + return s.executeQueryResult(result, request.TxControl, request.IgnoreTruncated) + } + + executeOptions := make([]options2.ExecuteOption, 0, len(opts)+1) + executeOptions = append(executeOptions, query.WithAllocator(a)) + for _, opt := range opts { + executeOptions = append(executeOptions, opt) + } + tx, res, err := query.Execute(ctx, s.queryClient, s.ID(), q, executeOptions...) if err != nil { return nil, nil, xerrors.WithStackTrace(err) } + if tx != nil { + txr = &transaction{ + id: tx.ID(), + s: s, + } + if txControl.Desc().GetCommitTx() { + txr.state.Set(txStateCommitted) + } else { + txr.state.Set(txStateInitialized) + txr.control = table.TxControl(table.WithTxID(tx.ID())) + } + } - return s.executeQueryResult(result, request.TxControl, request.IgnoreTruncated) + resultSets, stats, err := query.ReadAll(ctx, res) + if err != nil { + return nil, nil, xerrors.WithStackTrace(err) + } + + return txr, scanner.NewUnary(resultSets, stats, + scanner.WithIgnoreTruncated(true), + ), nil } // executeQueryResult returns Transaction and result built from received @@ -827,9 +876,9 @@ func (s *session) executeQueryResult( s: s, } if txControl.GetCommitTx() { - tx.state.Store(txStateCommitted) + tx.state.Set(txStateCommitted) } else { - tx.state.Store(txStateInitialized) + tx.state.Set(txStateInitialized) tx.control = table.TxControl(table.WithTxID(tx.id)) } @@ -853,7 +902,7 @@ func (s *session) executeDataQuery( response *Ydb_Table.ExecuteDataQueryResponse ) - response, err = s.tableService.ExecuteDataQuery(ctx, request, callOptions...) + response, err = s.client.ExecuteDataQuery(ctx, request, callOptions...) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -887,7 +936,7 @@ func (s *session) ExecuteSchemeQuery( opt((*options.ExecuteSchemeQueryDesc)(&request)) } } - _, err = s.tableService.ExecuteSchemeQuery(ctx, &request) + _, err = s.client.ExecuteSchemeQuery(ctx, &request) return xerrors.WithStackTrace(err) } @@ -909,7 +958,7 @@ func (s *session) DescribeTableOptions(ctx context.Context) ( operation.ModeSync, ), } - response, err = s.tableService.DescribeTableOptions(ctx, &request) + response, err = s.client.DescribeTableOptions(ctx, &request) if err != nil { return desc, xerrors.WithStackTrace(err) } @@ -1060,7 +1109,7 @@ func (s *session) StreamReadTable( ctx, cancel := xcontext.WithCancel(ctx) - stream, err = s.tableService.StreamReadTable(ctx, &request) + stream, err = s.client.StreamReadTable(ctx, &request) if err != nil { cancel() @@ -1122,7 +1171,7 @@ func (s *session) ReadRows( } } - response, err = s.tableService.ReadRows(ctx, &request) + response, err = s.client.ReadRows(ctx, &request) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -1180,7 +1229,7 @@ func (s *session) StreamExecuteScanQuery( ctx, cancel := xcontext.WithCancel(ctx) - stream, err = s.tableService.StreamExecuteScanQuery(ctx, &request, callOptions...) + stream, err = s.client.StreamExecuteScanQuery(ctx, &request, callOptions...) if err != nil { cancel() @@ -1242,7 +1291,7 @@ func (s *session) BulkUpsert(ctx context.Context, table string, rows value.Value } } - _, err = s.tableService.BulkUpsert(ctx, + _, err = s.client.BulkUpsert(ctx, &Ydb_Table.BulkUpsertRequest{ Table: table, Rows: value.ToYDB(rows, a), @@ -1281,7 +1330,7 @@ func (s *session) BeginTransaction( onDone(x, err) }() - response, err = s.tableService.BeginTransaction(ctx, + response, err = s.client.BeginTransaction(ctx, &Ydb_Table.BeginTransactionRequest{ SessionId: s.id, TxSettings: txSettings.Settings(), @@ -1305,7 +1354,7 @@ func (s *session) BeginTransaction( s: s, control: table.TxControl(table.WithTxID(result.GetTxMeta().GetId())), } - tx.state.Store(txStateInitialized) + tx.state.Set(txStateInitialized) return tx, nil } diff --git a/internal/table/session_test.go b/internal/table/session_test.go index 710bd57b6..91c3ea9d8 100644 --- a/internal/table/session_test.go +++ b/internal/table/session_test.go @@ -240,8 +240,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TableExecuteDataQuery, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } _, _, err := s.Execute(ctx, table.TxControl(), "", table.NewQueryParameters()) require.NoError(t, err) @@ -251,8 +251,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TableExplainDataQuery, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } _, err := s.Explain(ctx, "") require.NoError(t, err) @@ -262,8 +262,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TablePrepareDataQuery, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } _, err := s.Prepare(ctx, "") require.NoError(t, err) @@ -280,8 +280,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TableDeleteSession, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } require.NoError(t, s.Close(ctx)) }, @@ -290,8 +290,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TableBeginTransaction, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } _, err := s.BeginTransaction(ctx, table.TxSettings()) require.NoError(t, err) @@ -302,8 +302,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { do: func(t *testing.T, ctx context.Context, c *Client) { tx := &transaction{ s: &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), }, } _, err := tx.CommitTx(ctx) @@ -315,8 +315,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { do: func(t *testing.T, ctx context.Context, c *Client) { tx := &transaction{ s: &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), }, } err := tx.Rollback(ctx) @@ -327,8 +327,8 @@ func TestSessionOperationModeOnExecuteDataQuery(t *testing.T) { method: testutil.TableKeepAlive, do: func(t *testing.T, ctx context.Context, c *Client) { s := &session{ - tableService: Ydb_Table_V1.NewTableServiceClient(c.cc), - config: config.New(), + client: Ydb_Table_V1.NewTableServiceClient(c.cc), + config: config.New(), } require.NoError(t, s.KeepAlive(ctx)) }, diff --git a/internal/table/statement.go b/internal/table/statement.go index eb797e2ef..27cdd46ed 100644 --- a/internal/table/statement.go +++ b/internal/table/statement.go @@ -20,7 +20,7 @@ import ( type statement struct { session *session - query query + query queryRenameMe params map[string]*Ydb.Type } diff --git a/internal/table/transaction.go b/internal/table/transaction.go index a07576196..35bdfd9b0 100644 --- a/internal/table/transaction.go +++ b/internal/table/transaction.go @@ -25,15 +25,20 @@ var ( ) type txState struct { - rawVal atomic.Uint32 + atomic.Pointer[txStateEnum] } -func (s *txState) Load() txStateEnum { - return txStateEnum(s.rawVal.Load()) +func (s *txState) Get() txStateEnum { + ptr := s.Pointer.Load() + if ptr == nil { + return txStateInitialized + } + + return *ptr } -func (s *txState) Store(val txStateEnum) { - s.rawVal.Store(uint32(val)) +func (s *txState) Set(state txStateEnum) { + s.Pointer.Store(&state) } type txStateEnum uint32 @@ -52,6 +57,10 @@ type transaction struct { } func (tx *transaction) ID() string { + if tx == nil { + return "" + } + return tx.id } @@ -70,7 +79,7 @@ func (tx *transaction) Execute( onDone(r, err) }() - switch tx.state.Load() { + switch tx.state.Get() { case txStateCommitted: return nil, xerrors.WithStackTrace(errTxAlreadyCommitted) case txStateRollbacked: @@ -82,7 +91,7 @@ func (tx *transaction) Execute( } if tx.control.Desc().GetCommitTx() { - tx.state.Store(txStateCommitted) + tx.state.Set(txStateCommitted) } return r, nil @@ -107,7 +116,7 @@ func (tx *transaction) ExecuteStatement( onDone(r, err) }() - switch tx.state.Load() { + switch tx.state.Get() { case txStateCommitted: return nil, xerrors.WithStackTrace(errTxAlreadyCommitted) case txStateRollbacked: @@ -119,7 +128,7 @@ func (tx *transaction) ExecuteStatement( } if tx.control.Desc().GetCommitTx() { - tx.state.Store(txStateCommitted) + tx.state.Set(txStateCommitted) } return r, nil @@ -140,7 +149,7 @@ func (tx *transaction) CommitTx( onDone(err) }() - switch tx.state.Load() { + switch tx.state.Get() { case txStateCommitted: return nil, xerrors.WithStackTrace(errTxAlreadyCommitted) case txStateRollbacked: @@ -167,7 +176,7 @@ func (tx *transaction) CommitTx( } } - response, err = tx.s.tableService.CommitTransaction(ctx, request) + response, err = tx.s.client.CommitTransaction(ctx, request) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -177,7 +186,7 @@ func (tx *transaction) CommitTx( return nil, xerrors.WithStackTrace(err) } - tx.state.Store(txStateCommitted) + tx.state.Set(txStateCommitted) return scanner.NewUnary( nil, @@ -198,13 +207,13 @@ func (tx *transaction) Rollback(ctx context.Context) (err error) { onDone(err) }() - switch tx.state.Load() { + switch tx.state.Get() { case txStateCommitted: return nil // nop for committed tx case txStateRollbacked: return xerrors.WithStackTrace(errTxRollbackedEarly) default: - _, err = tx.s.tableService.RollbackTransaction(ctx, + _, err = tx.s.client.RollbackTransaction(ctx, &Ydb_Table.RollbackTransactionRequest{ SessionId: tx.s.id, TxId: tx.id, @@ -220,7 +229,7 @@ func (tx *transaction) Rollback(ctx context.Context) (err error) { return xerrors.WithStackTrace(err) } - tx.state.Store(txStateRollbacked) + tx.state.Set(txStateRollbacked) return nil } diff --git a/log/driver.go b/log/driver.go index fb9c8b1a1..1b01841e0 100644 --- a/log/driver.go +++ b/log/driver.go @@ -109,22 +109,22 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "dial") - endpoint := info.Endpoint + address := info.Address l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), ) start := time.Now() return func(info trace.DriverConnDialDoneInfo) { if info.Error == nil { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), ) } else { l.Log(WithLevel(ctx, WARN), "failed", Error(info.Error), - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), versionField(), ) @@ -136,16 +136,16 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(context.Background(), TRACE, "ydb", "driver", "conn", "state", "change") - endpoint := info.Endpoint + address := info.Address l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), Stringer("state", info.State), ) start := time.Now() return func(info trace.DriverConnStateChangeDoneInfo) { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), Stringer("state", info.State), ) @@ -156,22 +156,22 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "close") - endpoint := info.Endpoint + address := info.Address l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), ) start := time.Now() return func(info trace.DriverConnCloseDoneInfo) { if info.Error == nil { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), ) } else { l.Log(WithLevel(ctx, WARN), "failed", Error(info.Error), - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), versionField(), ) @@ -183,10 +183,10 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "invoke") - endpoint := info.Endpoint + address := info.Address method := string(info.Method) l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), String("method", method), ) start := time.Now() @@ -194,7 +194,7 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return func(info trace.DriverConnInvokeDoneInfo) { if info.Error == nil { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), String("method", method), latencyField(start), Stringer("metadata", metadata(info.Metadata)), @@ -202,7 +202,7 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo } else { l.Log(WithLevel(ctx, WARN), "failed", Error(info.Error), - Stringer("endpoint", endpoint), + String("address", address), String("method", method), latencyField(start), Stringer("metadata", metadata(info.Metadata)), @@ -220,10 +220,10 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "stream", "New") - endpoint := info.Endpoint + address := info.Address method := string(info.Method) l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), String("method", method), ) start := time.Now() @@ -231,14 +231,14 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return func(info trace.DriverConnNewStreamDoneInfo) { if info.Error == nil { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), String("method", method), latencyField(start), ) } else { l.Log(WithLevel(ctx, WARN), "failed", Error(info.Error), - Stringer("endpoint", endpoint), + String("address", address), String("method", method), latencyField(start), versionField(), @@ -319,17 +319,17 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "ban") - endpoint := info.Endpoint + address := info.Address cause := info.Cause l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), NamedError("cause", cause), ) start := time.Now() return func(info trace.DriverConnBanDoneInfo) { l.Log(WithLevel(ctx, WARN), "done", - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), Stringer("state", info.State), NamedError("cause", cause), @@ -342,15 +342,15 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo return nil } ctx := with(*info.Context, TRACE, "ydb", "driver", "conn", "allow") - endpoint := info.Endpoint + address := info.Address l.Log(ctx, "start", - Stringer("endpoint", endpoint), + String("address", address), ) start := time.Now() return func(info trace.DriverConnAllowDoneInfo) { l.Log(ctx, "done", - Stringer("endpoint", endpoint), + String("address", address), latencyField(start), Stringer("state", info.State), ) @@ -439,7 +439,7 @@ func internalDriver(l Logger, d trace.Detailer) trace.Driver { //nolint:gocyclo if info.Error == nil { l.Log(ctx, "done", latencyField(start), - Stringer("endpoint", info.Endpoint), + String("address", info.Address), ) } else { l.Log(WithLevel(ctx, ERROR), "failed", diff --git a/log/query.go b/log/query.go index e44e9e3ac..c52a68d5b 100644 --- a/log/query.go +++ b/log/query.go @@ -413,9 +413,7 @@ func internalQuery( } ctx := with(*info.Context, TRACE, "ydb", "query", "transaction", "execute") l.Log(ctx, "start", - String("SessionID", info.Session.ID()), String("TransactionID", info.Tx.ID()), - String("SessionStatus", info.Session.Status()), ) start := time.Now() diff --git a/metrics/node_id.go b/metrics/node_id.go index 176eace55..a06fdba6c 100644 --- a/metrics/node_id.go +++ b/metrics/node_id.go @@ -4,6 +4,6 @@ import ( "strconv" ) -func idToString(id uint32) string { - return strconv.FormatUint(uint64(id), 10) +func idToString(id int64) string { + return strconv.FormatInt(id, 10) } diff --git a/query/session.go b/query/session.go index 14da2f3d5..d53bd3854 100644 --- a/query/session.go +++ b/query/session.go @@ -78,6 +78,6 @@ func WithStatsMode(mode options.StatsMode) options.StatsModeOption { return options.WithStatsMode(mode) } -func WithCallOptions(opts ...grpc.CallOption) options.CallOptions { +func WithCallOptions(opts ...grpc.CallOption) options.GrpcOpts { return options.WithCallOptions(opts...) } diff --git a/table/options/options.go b/table/options/options.go index 922333a17..0a74c07fd 100644 --- a/table/options/options.go +++ b/table/options/options.go @@ -2,10 +2,12 @@ package options import ( "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Table" "google.golang.org/grpc" "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" ) @@ -834,22 +836,16 @@ type ( ExecuteDataQueryDesc struct { *Ydb_Table.ExecuteDataQueryRequest - IgnoreTruncated bool + IgnoreTruncated bool + WithQueryService bool } ExecuteDataQueryOption interface { + options.ExecuteOption + ApplyExecuteDataQueryOption(d *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption } - executeDataQueryOptionFunc func(d *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption ) -func (f executeDataQueryOptionFunc) ApplyExecuteDataQueryOption( - d *ExecuteDataQueryDesc, a *allocator.Allocator, -) []grpc.CallOption { - return f(d, a) -} - -var _ ExecuteDataQueryOption = executeDataQueryOptionFunc(nil) - type ( CommitTransactionDesc Ydb_Table.CommitTransactionRequest CommitTransactionOption func(*CommitTransactionDesc) @@ -871,6 +867,10 @@ func WithKeepInCache(keepInCache bool) ExecuteDataQueryOption { type withCallOptions []grpc.CallOption +func (opts withCallOptions) ApplyExecuteOption(s *options.Execute) { + s.GrpcCallOptions = append(s.GrpcCallOptions, opts...) +} + func (opts withCallOptions) ApplyExecuteScanQueryOption(d *ExecuteScanQueryDesc) []grpc.CallOption { return opts } @@ -890,22 +890,59 @@ func WithCallOptions(opts ...grpc.CallOption) withCallOptions { return opts } +type withCommitExecuteDataQueryOption struct{} + +func (withCommitExecuteDataQueryOption) ApplyExecuteOption(s *options.Execute) { + s.TxControl.Commit = true +} + +func (withCommitExecuteDataQueryOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + d.TxControl.CommitTx = true + + return nil +} + // WithCommit appends flag of commit transaction with executing query func WithCommit() ExecuteDataQueryOption { - return executeDataQueryOptionFunc(func(desc *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption { - desc.TxControl.CommitTx = true + return withCommitExecuteDataQueryOption{} +} - return nil - }) +type useQueryServiceExecuteOption struct{} + +func (useQueryServiceExecuteOption) ApplyExecuteOption(s *options.Execute) { +} + +func (u useQueryServiceExecuteOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + d.WithQueryService = true + + return nil +} + +// WithQueryService redirects request to query service client +func WithQueryService() ExecuteDataQueryOption { + return useQueryServiceExecuteOption{} +} + +type withIgnoreTruncatedOption struct{} + +func (withIgnoreTruncatedOption) ApplyExecuteOption(s *options.Execute) { +} + +func (w withIgnoreTruncatedOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + d.IgnoreTruncated = true + + return nil } // WithIgnoreTruncated mark truncated result as good (without error) func WithIgnoreTruncated() ExecuteDataQueryOption { - return executeDataQueryOptionFunc(func(desc *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption { - desc.IgnoreTruncated = true - - return nil - }) + return withIgnoreTruncatedOption{} } // WithQueryCachePolicyKeepInCache manages keep-in-cache policy @@ -932,20 +969,29 @@ func WithQueryCachePolicy(opts ...QueryCachePolicyOption) ExecuteDataQueryOption return withQueryCachePolicy(opts...) } -func withQueryCachePolicy(opts ...QueryCachePolicyOption) ExecuteDataQueryOption { - return executeDataQueryOptionFunc(func(d *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption { - if d.QueryCachePolicy == nil { - d.QueryCachePolicy = a.TableQueryCachePolicy() - d.QueryCachePolicy.KeepInCache = true - } - for _, opt := range opts { - if opt != nil { - opt((*queryCachePolicy)(d.QueryCachePolicy), a) - } +type withQueryCachePolicyOption []QueryCachePolicyOption + +func (opts withQueryCachePolicyOption) ApplyExecuteOption(s *options.Execute) { +} + +func (opts withQueryCachePolicyOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + if d.QueryCachePolicy == nil { + d.QueryCachePolicy = a.TableQueryCachePolicy() + d.QueryCachePolicy.KeepInCache = true + } + for _, opt := range opts { + if opt != nil { + opt((*queryCachePolicy)(d.QueryCachePolicy), a) } + } - return nil - }) + return nil +} + +func withQueryCachePolicy(opts ...QueryCachePolicyOption) ExecuteDataQueryOption { + return withQueryCachePolicyOption(opts) } func WithCommitCollectStatsModeNone() CommitTransactionOption { @@ -960,20 +1006,40 @@ func WithCommitCollectStatsModeBasic() CommitTransactionOption { } } +type withCollectStatsModeNoneOption struct{} + +func (withCollectStatsModeNoneOption) ApplyExecuteOption(s *options.Execute) { + s.StatsMode = options.StatsMode(Ydb_Query.StatsMode_STATS_MODE_NONE) +} + +func (withCollectStatsModeNoneOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + d.CollectStats = Ydb_Table.QueryStatsCollection_STATS_COLLECTION_NONE + + return nil +} + func WithCollectStatsModeNone() ExecuteDataQueryOption { - return executeDataQueryOptionFunc(func(d *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption { - d.CollectStats = Ydb_Table.QueryStatsCollection_STATS_COLLECTION_NONE + return withCollectStatsModeNoneOption{} +} - return nil - }) +type withCollectStatsModeBasicOption struct{} + +func (withCollectStatsModeBasicOption) ApplyExecuteOption(s *options.Execute) { + s.StatsMode = options.StatsMode(Ydb_Query.StatsMode_STATS_MODE_BASIC) } -func WithCollectStatsModeBasic() ExecuteDataQueryOption { - return executeDataQueryOptionFunc(func(d *ExecuteDataQueryDesc, a *allocator.Allocator) []grpc.CallOption { - d.CollectStats = Ydb_Table.QueryStatsCollection_STATS_COLLECTION_BASIC +func (withCollectStatsModeBasicOption) ApplyExecuteDataQueryOption( + d *ExecuteDataQueryDesc, a *allocator.Allocator, +) []grpc.CallOption { + d.CollectStats = Ydb_Table.QueryStatsCollection_STATS_COLLECTION_BASIC - return nil - }) + return nil +} + +func WithCollectStatsModeBasic() ExecuteDataQueryOption { + return withCollectStatsModeBasicOption{} } type ( diff --git a/tests/integration/table_with_query_service_test.go b/tests/integration/table_with_query_service_test.go new file mode 100644 index 000000000..90eaeac7e --- /dev/null +++ b/tests/integration/table_with_query_service_test.go @@ -0,0 +1,102 @@ +//go:build integration +// +build integration + +package integration + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/version" + "github.com/ydb-platform/ydb-go-sdk/v3/table" + "github.com/ydb-platform/ydb-go-sdk/v3/table/options" + "github.com/ydb-platform/ydb-go-sdk/v3/table/result/named" +) + +func TestTableWithQueryService(t *testing.T) { + if version.Lt(os.Getenv("YDB_VERSION"), "24.1") { + t.Skip("query service not allowed in YDB version '" + os.Getenv("YDB_VERSION") + "'") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db, err := ydb.Open(ctx, + os.Getenv("YDB_CONNECTION_STRING"), + ydb.WithAccessTokenCredentials(os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS")), + ) + require.NoError(t, err) + t.Run("table.Session.Execute", func(t *testing.T) { + var abc, def int32 + err = db.Table().Do(ctx, func(ctx context.Context, s table.Session) error { + _, res, err := s.Execute(ctx, table.DefaultTxControl(), + `SELECT 123 as abc, 456 as def;`, nil, + options.WithQueryService(), + ) + if err != nil { + return err + } + err = res.NextResultSetErr(ctx) + if err != nil { + return err + } + if !res.NextRow() { + if err = res.Err(); err != nil { + return err + } + return fmt.Errorf("unexpected empty result set") + } + err = res.ScanNamed( + named.Required("abc", &abc), + named.Required("def", &def), + ) + if err != nil { + return err + } + t.Log(abc, def) + return res.Err() + }, table.WithTxSettings(table.TxSettings(table.WithSnapshotReadOnly()))) + require.NoError(t, err) + require.EqualValues(t, 123, abc) + require.EqualValues(t, 456, def) + }) + t.Run("table.Transaction.Execute", func(t *testing.T) { + var abc, def int32 + err = db.Table().DoTx(ctx, func(ctx context.Context, tx table.TransactionActor) (err error) { + res, err := tx.Execute(ctx, + `SELECT 123 as abc, 456 as def;`, nil, + options.WithQueryService(), + ) + if err != nil { + return err + } + err = res.NextResultSetErr(ctx) + if err != nil { + return err + } + if !res.NextRow() { + if err = res.Err(); err != nil { + return err + } + return fmt.Errorf("unexpected empty result set") + } + err = res.ScanNamed( + named.Required("abc", &abc), + named.Required("def", &def), + ) + if err != nil { + return err + } + t.Log(abc, def) + return res.Err() + }, table.WithTxSettings(table.TxSettings(table.WithSnapshotReadOnly()))) + require.NoError(t, err) + require.EqualValues(t, 123, abc) + require.EqualValues(t, 456, def) + }) +} diff --git a/testutil/driver.go b/testutil/driver.go index 48b82d4ca..e55cd6975 100644 --- a/testutil/driver.go +++ b/testutil/driver.go @@ -138,7 +138,7 @@ type balancerStub struct { ) (grpc.ClientStream, error) } -func (b *balancerStub) HasNode(id uint32) bool { +func (b *balancerStub) HasNode(id int64) bool { return true } diff --git a/trace/driver.go b/trace/driver.go index 8cf5dedc3..aa9fb496e 100644 --- a/trace/driver.go +++ b/trace/driver.go @@ -159,10 +159,10 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo - State ConnState + Context *context.Context + Call call + Address string + State ConnState } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnStateChangeDoneInfo struct { @@ -274,9 +274,9 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo + Context *context.Context + Call call + Address string } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnDialDoneInfo struct { @@ -288,9 +288,9 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo + Context *context.Context + Call call + Address string } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnParkDoneInfo struct { @@ -302,9 +302,9 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo + Context *context.Context + Call call + Address string } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnCloseDoneInfo struct { @@ -316,11 +316,11 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo - State ConnState - Cause error + Context *context.Context + Call call + Address string + State ConnState + Cause error } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnBanDoneInfo struct { @@ -332,10 +332,10 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo - State ConnState + Context *context.Context + Call call + Address string + State ConnState } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnAllowDoneInfo struct { @@ -347,10 +347,10 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo - Method Method + Context *context.Context + Call call + Address string + Method Method } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnInvokeDoneInfo struct { @@ -366,10 +366,10 @@ type ( // Pointer to context provide replacement of context in trace callback function. // Warning: concurrent access to pointer on client side must be excluded. // Safe replacement of context are provided only inside callback function - Context *context.Context - Call call - Endpoint EndpointInfo - Method Method + Context *context.Context + Call call + Address string + Method Method } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverConnNewStreamDoneInfo struct { @@ -467,8 +467,8 @@ type ( } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverBalancerChooseEndpointDoneInfo struct { - Endpoint EndpointInfo - Error error + Address string + Error error } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals DriverRepeaterWakeUpStartInfo struct { diff --git a/trace/driver_gtrace.go b/trace/driver_gtrace.go index 50491225a..ccffb49c2 100644 --- a/trace/driver_gtrace.go +++ b/trace/driver_gtrace.go @@ -1314,11 +1314,11 @@ func DriverOnResolve(t *Driver, call call, target string, resolved []string) fun } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnStateChange(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { +func DriverOnConnStateChange(t *Driver, c *context.Context, call call, address string, state ConnState) func(state ConnState) { var p DriverConnStateChangeStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address p.State = state res := t.onConnStateChange(p) return func(state ConnState) { @@ -1328,11 +1328,11 @@ func DriverOnConnStateChange(t *Driver, c *context.Context, call call, endpoint } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnInvoke(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(_ error, issues []Issue, opID string, state ConnState, metadata map[string][]string) { +func DriverOnConnInvoke(t *Driver, c *context.Context, call call, address string, m Method) func(_ error, issues []Issue, opID string, state ConnState, metadata map[string][]string) { var p DriverConnInvokeStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address p.Method = m res := t.onConnInvoke(p) return func(e error, issues []Issue, opID string, state ConnState, metadata map[string][]string) { @@ -1346,11 +1346,11 @@ func DriverOnConnInvoke(t *Driver, c *context.Context, call call, endpoint Endpo } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnNewStream(t *Driver, c *context.Context, call call, endpoint EndpointInfo, m Method) func(_ error, state ConnState) { +func DriverOnConnNewStream(t *Driver, c *context.Context, call call, address string, m Method) func(_ error, state ConnState) { var p DriverConnNewStreamStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address p.Method = m res := t.onConnNewStream(p) return func(e error, state ConnState) { @@ -1397,11 +1397,11 @@ func DriverOnConnStreamCloseSend(t *Driver, c *context.Context, call call) func( } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnDial(t *Driver, c *context.Context, call call, endpoint EndpointInfo) func(error) { +func DriverOnConnDial(t *Driver, c *context.Context, call call, address string) func(error) { var p DriverConnDialStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address res := t.onConnDial(p) return func(e error) { var p DriverConnDialDoneInfo @@ -1410,11 +1410,11 @@ func DriverOnConnDial(t *Driver, c *context.Context, call call, endpoint Endpoin } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnBan(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState, cause error) func(state ConnState) { +func DriverOnConnBan(t *Driver, c *context.Context, call call, address string, state ConnState, cause error) func(state ConnState) { var p DriverConnBanStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address p.State = state p.Cause = cause res := t.onConnBan(p) @@ -1425,11 +1425,11 @@ func DriverOnConnBan(t *Driver, c *context.Context, call call, endpoint Endpoint } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnAllow(t *Driver, c *context.Context, call call, endpoint EndpointInfo, state ConnState) func(state ConnState) { +func DriverOnConnAllow(t *Driver, c *context.Context, call call, address string, state ConnState) func(state ConnState) { var p DriverConnAllowStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address p.State = state res := t.onConnAllow(p) return func(state ConnState) { @@ -1439,11 +1439,11 @@ func DriverOnConnAllow(t *Driver, c *context.Context, call call, endpoint Endpoi } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnPark(t *Driver, c *context.Context, call call, endpoint EndpointInfo) func(error) { +func DriverOnConnPark(t *Driver, c *context.Context, call call, address string) func(error) { var p DriverConnParkStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address res := t.onConnPark(p) return func(e error) { var p DriverConnParkDoneInfo @@ -1452,11 +1452,11 @@ func DriverOnConnPark(t *Driver, c *context.Context, call call, endpoint Endpoin } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnConnClose(t *Driver, c *context.Context, call call, endpoint EndpointInfo) func(error) { +func DriverOnConnClose(t *Driver, c *context.Context, call call, address string) func(error) { var p DriverConnCloseStartInfo p.Context = c p.Call = call - p.Endpoint = endpoint + p.Address = address res := t.onConnClose(p) return func(e error) { var p DriverConnCloseDoneInfo @@ -1504,14 +1504,14 @@ func DriverOnBalancerClose(t *Driver, c *context.Context, call call) func(error) } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func DriverOnBalancerChooseEndpoint(t *Driver, c *context.Context, call call) func(endpoint EndpointInfo, _ error) { +func DriverOnBalancerChooseEndpoint(t *Driver, c *context.Context, call call) func(address string, _ error) { var p DriverBalancerChooseEndpointStartInfo p.Context = c p.Call = call res := t.onBalancerChooseEndpoint(p) - return func(endpoint EndpointInfo, e error) { + return func(address string, e error) { var p DriverBalancerChooseEndpointDoneInfo - p.Endpoint = endpoint + p.Address = address p.Error = e res(p) } diff --git a/trace/query.go b/trace/query.go index 3fbaef522..adca7dee9 100644 --- a/trace/query.go +++ b/trace/query.go @@ -135,9 +135,8 @@ type ( Context *context.Context Call call - Session querySessionInfo - Tx queryTransactionInfo - Query string + Tx queryTransactionInfo + Query string } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals QueryTxExecuteDoneInfo struct { diff --git a/trace/query_gtrace.go b/trace/query_gtrace.go index 2bbc46f3c..0bf3cc7aa 100644 --- a/trace/query_gtrace.go +++ b/trace/query_gtrace.go @@ -1460,11 +1460,10 @@ func QueryOnSessionBegin(t *Query, c *context.Context, call call, session queryS } } // Internals: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#internals -func QueryOnTxExecute(t *Query, c *context.Context, call call, session querySessionInfo, tx queryTransactionInfo, query string) func(error) { +func QueryOnTxExecute(t *Query, c *context.Context, call call, tx queryTransactionInfo, query string) func(error) { var p QueryTxExecuteStartInfo p.Context = c p.Call = call - p.Session = session p.Tx = tx p.Query = query res := t.onTxExecute(p)