diff --git a/client/gosqldriver/statement.go b/client/gosqldriver/statement.go index 0defe2c2..c54e8f36 100644 --- a/client/gosqldriver/statement.go +++ b/client/gosqldriver/statement.go @@ -112,6 +112,11 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { // Implement driver.StmtExecContext method to execute a DML func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { //TODO: honor the context timeout and return when it is canceled + if err := st.hera.watchCancel(ctx); err != nil { + fmt.Println("=== reach here 1 ====") + return nil, err + } + sk := 0 if len(st.hera.getShardKeyPayload()) > 0 { sk = 1 @@ -127,6 +132,8 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv cmd := netstring.NewNetstringEmbedded(nss) err = st.hera.execNs(cmd) if err != nil { + fmt.Println("=== reach here 2 ====") + st.hera.finish() return nil, err } diff --git a/client/gosqldriver/statement_test.go b/client/gosqldriver/statement_test.go index b3a138ac..ccf2fdde 100644 --- a/client/gosqldriver/statement_test.go +++ b/client/gosqldriver/statement_test.go @@ -295,6 +295,27 @@ func TestStmtExecContext(t *testing.T) { expectedError: "Unknown code: 999, data: Unknown code", expectedRows: 0, }, + { + name: "ExecContext with cancelled context", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.responses = []netstring.Netstring{ + {Cmd: common.RcValue, Payload: []byte("2")}, + } + }, + cancelCtx: true, + expectedError: "context canceled", + expectedRows: 0, + }, + { + name: "ExecContext with execNs failure", + args: []driver.NamedValue{{Ordinal: 1, Value: 1}}, + setupMock: func(mock *mockHeraConnection) { + mock.execErr = errors.New("mock execNs error") + }, + expectedError: "mock execNs error", + expectedRows: 0, + }, } for _, tt := range tests {