From 83fa609720e3491b74d409c3fedeb7841b298c14 Mon Sep 17 00:00:00 2001 From: mazchew Date: Mon, 19 Aug 2024 21:48:27 +0800 Subject: [PATCH] refactoring of exec/execContext and query/queryContext, as well as add context management support and relevant UTs --- client/gosqldriver/connection.go | 81 +++++++ client/gosqldriver/connection_test.go | 236 +++++++++++++++++++ client/gosqldriver/rows_test.go | 5 +- client/gosqldriver/statement.go | 327 ++++++++++---------------- client/gosqldriver/statement_test.go | 29 +++ 5 files changed, 471 insertions(+), 207 deletions(-) create mode 100644 client/gosqldriver/connection_test.go diff --git a/client/gosqldriver/connection.go b/client/gosqldriver/connection.go index 848dfdfa..6605d1db 100644 --- a/client/gosqldriver/connection.go +++ b/client/gosqldriver/connection.go @@ -19,6 +19,7 @@ package gosqldriver import ( + "context" "database/sql/driver" "errors" "fmt" @@ -51,6 +52,10 @@ type heraConnectionInterface interface { getCorrID() *netstring.Netstring getShardKeyPayload() []byte setCorrID(*netstring.Netstring) + startWatcher() + finish() + cancel(err error) + watchCancel(ctx context.Context) error } type heraConnection struct { @@ -62,6 +67,12 @@ type heraConnection struct { // correlation id corrID *netstring.Netstring clientinfo *netstring.Netstring + + // Context support + watching bool + watcher chan<- context.Context + finished chan<- struct{} + closech chan struct{} } // NewHeraConnection creates a structure implementing a driver.Con interface @@ -70,9 +81,79 @@ func NewHeraConnection(conn net.Conn) driver.Conn { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, hera.id, "create driver connection") } + + hera.startWatcher() + return hera } +func (c *heraConnection) startWatcher() { + watcher := make(chan context.Context, 1) + c.watcher = watcher + finished := make(chan struct{}) + c.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-c.closech: + return + } + + select { + case <-ctx.Done(): + c.cancel(ctx.Err()) + case <-finished: + case <-c.closech: + return + } + } + }() +} + +func (c *heraConnection) finish() { + if !c.watching || c.finished == nil { + return + } + select { + case c.finished <- struct{}{}: + c.watching = false + case <-c.closech: + } +} + +func (c *heraConnection) cancel(err error) { + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, c.id, "ctx error:", err) + } + c.Close() +} + +func (c *heraConnection) watchCancel(ctx context.Context) error { + if c.watching { + // Reach here if canceled, the connection is already invalid + c.Close() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + + if c.watcher == nil { + return nil + } + + c.watching = true + c.watcher <- ctx + return nil +} + // Prepare returns a prepared statement, bound to this connection. func (c *heraConnection) Prepare(query string) (driver.Stmt, error) { if logger.GetLogger().V(logger.Debug) { diff --git a/client/gosqldriver/connection_test.go b/client/gosqldriver/connection_test.go new file mode 100644 index 00000000..809254aa --- /dev/null +++ b/client/gosqldriver/connection_test.go @@ -0,0 +1,236 @@ +package gosqldriver + +import ( + "context" + "errors" + "net" + "sync" + "testing" + "time" +) + +func TestNewHeraConnection(t *testing.T) { + // Using net.Pipe to create a simple in-memory connection + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + heraConn := NewHeraConnection(clientConn).(*heraConnection) + + if heraConn.conn != clientConn { + t.Fatalf("expected conn to be initialized with clientConn") + } + if heraConn.watcher == nil || heraConn.finished == nil { + t.Fatalf("expected watcher and finished channels to be initialized") + } +} + +func TestStartWatcher_CancelContext(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + mockHera := NewHeraConnection(clientConn).(*heraConnection) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // WaitGroup to ensure that the context is being watched and that Close() is called + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + mockHera.watchCancel(ctx) + }() + + cancel() // Cancel the context + + // Wait for the goroutine to finish processing the cancellation + wg.Wait() + + // Allow some time for the goroutine to process the cancellation + time.Sleep(500 * time.Millisecond) + + // Test should finish without checking the connection closure directly + // TODO: seems like there is an issue with closech currently, where it doesn't seem to be instantiated as part of heraConnection + t.Log("Test completed successfully, context cancellation was processed.") +} + +func TestFinish(t *testing.T) { + tests := []struct { + name string + watching bool + finished chan struct{} + closech chan struct{} + expectFinished bool + expectWatching bool + }{ + { + name: "Finish with watching true and finished channel", + watching: true, + finished: make(chan struct{}, 1), + closech: make(chan struct{}), + expectFinished: true, + expectWatching: false, + }, + { + name: "Finish with watching false", + watching: false, + finished: make(chan struct{}, 1), + closech: make(chan struct{}), + expectFinished: false, + expectWatching: false, + }, + { + name: "Finish with nil finished channel", + watching: true, + finished: nil, + closech: make(chan struct{}), + expectFinished: false, + expectWatching: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHera := &heraConnection{ + watching: tt.watching, + finished: tt.finished, + closech: tt.closech, + } + + mockHera.finish() + + // Check if the finished channel received a signal + if tt.expectFinished { + select { + case <-tt.finished: + // Success case: Signal received + default: + t.Fatalf("expected signal on finished channel, but got none") + } + } else if tt.finished != nil { + select { + case <-tt.finished: + t.Fatalf("did not expect signal on finished channel, but got one") + default: + // Success case: No signal as expected + } + } + + // Check if watching is set to false after finishing + if mockHera.watching != tt.expectWatching { + t.Fatalf("expected watching to be %v, got %v", tt.expectWatching, mockHera.watching) + } + }) + } +} + +func TestCancel(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + // Create an instance of heraConnection with a valid connection + mockHera := &heraConnection{ + id: "test-id", + conn: clientConn, + } + + // Simulate an error that triggers cancel() + err := errors.New("mock error") + mockHera.cancel(err) + + // Check if the connection was closed + if !mockHera.isClosed() { + t.Fatalf("expected connection to be closed after cancel() is called") + } +} + +func TestWatchCancel(t *testing.T) { + tests := []struct { + name string + watching bool + ctx context.Context + watcher chan context.Context + expectClose bool + expectedErr error + expectWatching bool + }{ + { + name: "Already watching a different context", + watching: true, + ctx: context.Background(), + watcher: make(chan context.Context, 1), + expectClose: true, // The new connection should be closed + expectedErr: nil, + expectWatching: true, // The original connection remains watching + }, + { + name: "Context already canceled", + watching: false, + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + return ctx + }(), + expectedErr: context.Canceled, + expectWatching: false, + }, + { + name: "Non-cancellable context", + watching: false, + ctx: context.Background(), + expectedErr: nil, + expectWatching: false, + }, + { + name: "Valid context, start watching", + watching: false, + ctx: func() context.Context { + ctx, _ := context.WithCancel(context.Background()) // Ensure ctx is cancellable + return ctx + }(), + watcher: make(chan context.Context, 1), + expectedErr: nil, + expectWatching: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + mockHera := &heraConnection{ + watching: tt.watching, + conn: clientConn, + watcher: tt.watcher, + } + + err := mockHera.watchCancel(tt.ctx) + + // Verify the returned error + if err != tt.expectedErr { + t.Fatalf("expected error %v, got %v", tt.expectedErr, err) + } + + // Check if the new connection was closed (only relevant for the "Already watching a different context" case) + if tt.expectClose && !mockHera.isClosed() { + t.Fatalf("expected connection to be closed, but it wasn't") + } + + // Check if watching is set correctly for the original connection + if mockHera.watching != tt.expectWatching { + t.Fatalf("expected watching to be %v, got %v", tt.expectWatching, mockHera.watching) + } + }) + } +} + +func (c *heraConnection) isClosed() bool { + // Attempt to write to the connection to check if it is closed + _, err := c.conn.Write([]byte("test")) + return err != nil +} diff --git a/client/gosqldriver/rows_test.go b/client/gosqldriver/rows_test.go index 79c6a7b4..df85a412 100644 --- a/client/gosqldriver/rows_test.go +++ b/client/gosqldriver/rows_test.go @@ -12,8 +12,9 @@ import ( type mockHeraConnection struct { heraConnection - responses []netstring.Netstring - execErr error + responses []netstring.Netstring + execErr error + finishCalled bool } func (m *mockHeraConnection) Prepare(query string) (driver.Stmt, error) { diff --git a/client/gosqldriver/statement.go b/client/gosqldriver/statement.go index e63aec15..0defe2c2 100644 --- a/client/gosqldriver/statement.go +++ b/client/gosqldriver/statement.go @@ -90,82 +90,27 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { if st.hera.getCorrID() != nil { crid = 1 } - binds := len(args) - nss := make([]*netstring.Netstring, crid /*CmdClientCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and CmdBindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/) - idx := 0 - if crid == 1 { - nss[0] = st.hera.getCorrID() - st.hera.setCorrID(nil) - idx++ - } - nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) - idx++ - for _, val := range args { - nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(fmt.Sprintf("p%d", (idx-crid)/2+1))) - idx++ - switch val := val.(type) { - default: - return nil, fmt.Errorf("unexpected parameter type %T, only int,string and []byte supported", val) - case int: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(fmt.Sprintf("%d", int(val)))) - case int64: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(fmt.Sprintf("%d", int(val)))) - case []byte: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, val) - case string: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) - } - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) - } - idx++ - } - if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) - idx++ - } - nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) - cmd := netstring.NewNetstringEmbedded(nss) - err := st.hera.execNs(cmd) + nss, err := prepareExecNetstrings(st, convertValueToNamedValue(args), crid, sk) if err != nil { return nil, err } - ns, err := st.hera.getResponse() - if err != nil { - return nil, err - } - if ns.Cmd != common.RcValue { - switch ns.Cmd { - case common.RcSQLError: - return nil, fmt.Errorf("SQL error: %s", string(ns.Payload)) - case common.RcError: - return nil, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) - default: - return nil, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) - } - } - // it was columns number, irelevant for DML - ns, err = st.hera.getResponse() + cmd := netstring.NewNetstringEmbedded(nss) + err = st.hera.execNs(cmd) if err != nil { return nil, err } - if ns.Cmd != common.RcValue { - return nil, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) - } - res := &result{} - res.nRows, err = strconv.Atoi(string(ns.Payload)) + nRows, err := handleExecResponse(st) if err != nil { return nil, err } if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", res.nRows) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", nRows) } - return res, nil + return &result{nRows: nRows}, nil } // Implement driver.StmtExecContext method to execute a DML func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - //TODO: refactor ExecContext / Exec to reuse code //TODO: honor the context timeout and return when it is canceled sk := 0 if len(st.hera.getShardKeyPayload()) > 0 { @@ -175,86 +120,68 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv if st.hera.getCorrID() != nil { crid = 1 } - binds := len(args) - nss := make([]*netstring.Netstring, crid /*CmdClientCalCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/) - idx := 0 - if crid == 1 { - nss[0] = st.hera.getCorrID() - st.hera.setCorrID(nil) - idx++ - } - nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) - idx++ - for _, val := range args { - if len(val.Name) > 0 { - nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(val.Name)) - } else { - nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(fmt.Sprintf("p%d", (idx-crid)/2+1))) - } - idx++ - switch val := val.Value.(type) { - default: - return nil, fmt.Errorf("unexpected parameter type %T, only int,string and []byte supported", val) - case int: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(fmt.Sprintf("%d", int(val)))) - case int64: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(fmt.Sprintf("%d", int(val)))) - case []byte: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, val) - case string: - nss[idx] = netstring.NewNetstringFrom(common.CmdBindValue, []byte(val)) - } - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, st.hera.getID(), "Bind name =", string(nss[idx-1].Payload), ", value=", string(nss[idx].Payload)) - } - idx++ - } - if sk == 1 { - nss[idx] = netstring.NewNetstringFrom(common.CmdShardKey, st.hera.getShardKeyPayload()) - idx++ + nss, err := prepareExecNetstrings(st, args, crid, sk) + if err != nil { + return nil, err } - nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) cmd := netstring.NewNetstringEmbedded(nss) - err := st.hera.execNs(cmd) + err = st.hera.execNs(cmd) if err != nil { return nil, err } - ns, err := st.hera.getResponse() + + nRows, err := handleExecResponse(st) if err != nil { return nil, err } - if ns.Cmd != common.RcValue { - switch ns.Cmd { - case common.RcSQLError: - return nil, fmt.Errorf("SQL error: %s", string(ns.Payload)) - case common.RcError: - return nil, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) - default: - return nil, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) - } + + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", nRows) } - // it was columns number, irelevant for DML - ns, err = st.hera.getResponse() + return &result{nRows: nRows}, nil +} + +// Implements driver.Stmt. +// Query executes a query that may return rows, such as a SELECT. +func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { + sk := 0 + if len(st.hera.getShardKeyPayload()) > 0 { + sk = 1 + } + crid := 0 + if st.hera.getCorrID() != nil { + crid = 1 + } + nss, err := prepareQueryNetstrings(st, convertValueToNamedValue(args), crid, sk) if err != nil { return nil, err } - if ns.Cmd != common.RcValue { - return nil, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) + + cmd := netstring.NewNetstringEmbedded(nss) + err = st.hera.execNs(cmd) + if err != nil { + return nil, err } - res := &result{} - res.nRows, err = strconv.Atoi(string(ns.Payload)) + + cols, err := handleQueryResponse(st) if err != nil { return nil, err } + if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.getID(), "DML successfull, rows affected:", res.nRows) + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successful, num columns:", cols) } - return res, nil + return newRows(st.hera, cols, st.fetchChunkSize) } -// Implements driver.Stmt. -// Query executes a query that may return rows, such as a SELECT. -func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { +// Implements driver.StmtQueryContextx +// QueryContext executes a query that may return rows, such as a SELECT +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + // TODO: honor the context timeout and return when it is canceled + if err := st.hera.watchCancel(ctx); err != nil { + return nil, err + } + sk := 0 if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 @@ -263,6 +190,33 @@ func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { if st.hera.getCorrID() != nil { crid = 1 } + nss, err := prepareQueryNetstrings(st, args, crid, sk) + if err != nil { + return nil, err + } + cmd := netstring.NewNetstringEmbedded(nss) + err = st.hera.execNs(cmd) + if err != nil { + st.hera.finish() + return nil, err + } + + cols, err := handleQueryResponse(st) + if err != nil { + return nil, err + } + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successfull, num columns:", cols) + } + return newRows(st.hera, cols, st.fetchChunkSize) +} + +// implementing the extension HeraStmt interface +func (st *stmt) SetFetchSize(num int) { + st.fetchChunkSize = []byte(fmt.Sprintf("%d", num)) +} + +func prepareQueryNetstrings(st *stmt, args []driver.NamedValue, crid, sk int) ([]*netstring.Netstring, error) { binds := len(args) nss := make([]*netstring.Netstring, crid /*CmdClientCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/ +1 /* CmdFetch */) idx := 0 @@ -274,9 +228,13 @@ func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { nss[idx] = netstring.NewNetstringFrom(common.CmdPrepareV2, []byte(st.sql)) idx++ for _, val := range args { - nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(fmt.Sprintf("p%d", (idx-crid)/2+1))) + if len(val.Name) > 0 { + nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(val.Name)) + } else { + nss[idx] = netstring.NewNetstringFrom(common.CmdBindName, []byte(fmt.Sprintf("p%d", (idx-crid)/2+1))) + } idx++ - switch val := val.(type) { + switch val := val.Value.(type) { default: return nil, fmt.Errorf("unexpected parameter type %T, only int,string and []byte supported", val) case int: @@ -300,18 +258,16 @@ func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) idx++ nss[idx] = netstring.NewNetstringFrom(common.CmdFetch, st.fetchChunkSize) - cmd := netstring.NewNetstringEmbedded(nss) - err := st.hera.execNs(cmd) - if err != nil { - return nil, err - } + return nss, nil +} +func handleQueryResponse(st *stmt) (int, error) { var ns *netstring.Netstring -Loop: + var err error for { ns, err = st.hera.getResponse() if err != nil { - return nil, err + return 0, err } if ns.Cmd != common.RcValue { switch ns.Cmd { @@ -319,56 +275,38 @@ Loop: if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, st.hera.getID(), " Still executing ...") } - // continues the loop case common.RcSQLError: - return nil, fmt.Errorf("SQL error: %s", string(ns.Payload)) + return 0, fmt.Errorf("SQL error: %s", string(ns.Payload)) case common.RcError: - return nil, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) + return 0, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) default: - return nil, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) + return 0, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) } } else { - break Loop + break } } cols, err := strconv.Atoi(string(ns.Payload)) if err != nil { - return nil, err + return 0, err } - ns, err = st.hera.getResponse() if err != nil { - return nil, err + return 0, err } if ns.Cmd != common.RcValue { - return nil, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) + return 0, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) } - // number of rows is ignored _, err = strconv.Atoi(string(ns.Payload)) if err != nil { - return nil, err - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successfull, num columns:", cols) + return 0, err } - return newRows(st.hera, cols, st.fetchChunkSize) + return cols, nil } -// Implements driver.StmtQueryContextx -// QueryContext executes a query that may return rows, such as a SELECT -func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - // TODO: refactor Query/QueryContext to reuse code - // TODO: honor the context timeout and return when it is canceled - sk := 0 - if len(st.hera.getShardKeyPayload()) > 0 { - sk = 1 - } - crid := 0 - if st.hera.getCorrID() != nil { - crid = 1 - } +func prepareExecNetstrings(st *stmt, args []driver.NamedValue, crid, sk int) ([]*netstring.Netstring, error) { binds := len(args) - nss := make([]*netstring.Netstring, crid /*ClientCalCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*ShardKey*/ +1 /*Execute*/ +1 /* Fetch */) + nss := make([]*netstring.Netstring, crid /*CmdClientCalCorrelationID*/ +1 /*CmdPrepare*/ +2*binds /* CmdBindName and BindValue */ +sk /*CmdShardKey*/ +1 /*CmdExecute*/) idx := 0 if crid == 1 { nss[0] = st.hera.getCorrID() @@ -406,63 +344,42 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri idx++ } nss[idx] = netstring.NewNetstringFrom(common.CmdExecute, nil) - idx++ - nss[idx] = netstring.NewNetstringFrom(common.CmdFetch, st.fetchChunkSize) - cmd := netstring.NewNetstringEmbedded(nss) - err := st.hera.execNs(cmd) + return nss, nil +} + +func handleExecResponse(st *stmt) (int, error) { + ns, err := st.hera.getResponse() if err != nil { - return nil, err + return 0, err } - - var ns *netstring.Netstring -Loop: - for { - ns, err = st.hera.getResponse() - if err != nil { - return nil, err - } - if ns.Cmd != common.RcValue { - switch ns.Cmd { - case common.RcStillExecuting: - if logger.GetLogger().V(logger.Info) { - logger.GetLogger().Log(logger.Info, st.hera.getID(), " Still executing ...") - } - // continues the loop - case common.RcSQLError: - return nil, fmt.Errorf("SQL error: %s", string(ns.Payload)) - case common.RcError: - return nil, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) - default: - return nil, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) - } - } else { - break Loop + if ns.Cmd != common.RcValue { + switch ns.Cmd { + case common.RcSQLError: + return 0, fmt.Errorf("SQL error: %s", string(ns.Payload)) + case common.RcError: + return 0, fmt.Errorf("Internal hera error: %s", string(ns.Payload)) + default: + return 0, fmt.Errorf("Unknown code: %d, data: %s", ns.Cmd, string(ns.Payload)) } } - cols, err := strconv.Atoi(string(ns.Payload)) - if err != nil { - return nil, err - } - ns, err = st.hera.getResponse() if err != nil { - return nil, err + return 0, err } if ns.Cmd != common.RcValue { - return nil, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) + return 0, fmt.Errorf("Unknown code2: %d, data: %s", ns.Cmd, string(ns.Payload)) } - // number of rows is ignored - _, err = strconv.Atoi(string(ns.Payload)) + nRows, err := strconv.Atoi(string(ns.Payload)) if err != nil { - return nil, err - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, st.hera.getID(), "Query successfull, num columns:", cols) + return 0, err } - return newRows(st.hera, cols, st.fetchChunkSize) + return nRows, nil } -// implementing the extension HeraStmt interface -func (st *stmt) SetFetchSize(num int) { - st.fetchChunkSize = []byte(fmt.Sprintf("%d", num)) +func convertValueToNamedValue(args []driver.Value) []driver.NamedValue { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return namedArgs } diff --git a/client/gosqldriver/statement_test.go b/client/gosqldriver/statement_test.go index 0c9aec08..b3a138ac 100644 --- a/client/gosqldriver/statement_test.go +++ b/client/gosqldriver/statement_test.go @@ -3,6 +3,7 @@ package gosqldriver import ( "context" "database/sql/driver" + "errors" "testing" "github.com/paypal/hera/common" @@ -21,6 +22,10 @@ func (m *mockHeraConnection) getShardKeyPayload() []byte { return make([]byte, 1) } +func (m *mockHeraConnection) finish() { + m.finishCalled = true +} + func TestNewStmt(t *testing.T) { mockHera := &mockHeraConnection{} @@ -571,6 +576,27 @@ func TestStmtQueryContext(t *testing.T) { expectedError: "Unknown code: 999, data: Unknown code", expectedCols: 0, }, + { + name: "QueryContext with cancelled context", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, + } + }, + cancelCtx: true, + expectedError: "context canceled", + expectedCols: 0, + }, + { + name: "QueryContext with execNs failure", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.execErr = errors.New("mock execNs error") + }, + expectedError: "mock execNs error", + expectedCols: 0, + }, } for _, tt := range tests { @@ -597,6 +623,9 @@ func TestStmtQueryContext(t *testing.T) { if err == nil || err.Error() != tt.expectedError { t.Fatalf("expected error %v, got %v", tt.expectedError, err) } + if mockHera.execErr != nil && !mockHera.finishCalled { + t.Fatalf("expected finish() to be called, but it was not") + } return } if err != nil {