From 9fa33e2d5ae45e59beb296914e7b6939e8c8dc68 Mon Sep 17 00:00:00 2001 From: Michael Hobbs Date: Wed, 1 Sep 2021 16:16:18 -0700 Subject: [PATCH] implement ConnPrepareContext/StmtQueryContext/StmtExecContext interfaces --- conn.go | 4 ++ conn_go18.go | 79 +++++++++++++++++++++++- conn_test.go | 164 +++++++++++++++++++++++++++++++++++++++++++++++++ issues_test.go | 36 ++++++++++- 4 files changed, 281 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index b09a17047..8e445f32c 100644 --- a/conn.go +++ b/conn.go @@ -1360,6 +1360,10 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { + return st.query(v) +} + +func (st *stmt) query(v []driver.Value) (r *rows, err error) { if st.cn.getBad() { return nil, driver.ErrBadConn } diff --git a/conn_go18.go b/conn_go18.go index 2b9a9599e..3c83082b3 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -11,6 +11,10 @@ import ( "time" ) +const ( + watchCancelDialContextTimeout = time.Second * 10 +) + // Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { list := make([]driver.Value, len(args)) @@ -43,6 +47,14 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam return cn.Exec(query, list) } +// Implement the "ConnPrepareContext" interface +func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + return cn.Prepare(query) +} + // Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { var mode string @@ -109,7 +121,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() { // so it must not be used for the additional network // request to cancel the query. // Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) defer cancel() _ = cn.cancel(ctxCancel) @@ -172,3 +184,68 @@ func (cn *conn) cancel(ctx context.Context) error { return err } } + +// Implement the "StmtQueryContext" interface +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := st.watchCancel(ctx) + r, err := st.query(list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "StmtExecContext" interface +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := st.watchCancel(ctx); finish != nil { + defer finish() + } + + return st.Exec(list) +} + +// watchCancel is implemented on stmt in order to not mark the parent conn as bad +func (st *stmt) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. + // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) + defer cancel() + + _ = st.cancel(ctxCancel) + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (st *stmt) cancel(ctx context.Context) error { + return st.cn.cancel(ctx) +} diff --git a/conn_test.go b/conn_test.go index a05d81d0b..4ac3d2b82 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1806,3 +1806,167 @@ func TestCopyInStmtAffectedRows(t *testing.T) { res.RowsAffected() res.LastInsertId() } + +func TestConnPrepareContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + err error + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT 1", + err: nil, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Microsecond) + }, + sql: "SELECT 1", + err: context.DeadlineExceeded, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT 1", + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + _, err := db.PrepareContext(ctx, tt.sql) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + } + }) + } +} + +func TestStmtQueryContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + err error + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT pg_sleep(1);", + err: nil, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 1*time.Second) + }, + sql: "SELECT pg_sleep(10);", + err: &Error{Message: "canceling statement due to user request"}, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT pg_sleep(1);", + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + stmt, err := db.PrepareContext(ctx, tt.sql) + if err != nil { + t.Fatal(err) + } + _, err = stmt.QueryContext(ctx) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + } + }) + } +} + +func TestStmtExecContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + err error + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT pg_sleep(1);", + err: nil, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 1*time.Second) + }, + sql: "SELECT pg_sleep(10);", + err: &Error{Message: "canceling statement due to user request"}, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT pg_sleep(1);", + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + stmt, err := db.PrepareContext(ctx, tt.sql) + if err != nil { + t.Fatal(err) + } + _, err = stmt.ExecContext(ctx) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + } + }) + } +} diff --git a/issues_test.go b/issues_test.go index 3a330a0a9..55d3f1ec3 100644 --- a/issues_test.go +++ b/issues_test.go @@ -1,6 +1,10 @@ package pq -import "testing" +import ( + "context" + "testing" + "time" +) func TestIssue494(t *testing.T) { db := openTestConn(t) @@ -24,3 +28,33 @@ func TestIssue494(t *testing.T) { t.Fatal("expected error") } } + +func TestIssue1046(t *testing.T) { + ctxTimeout := time.Second * 2 + + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`) + if err != nil { + t.Fatal(err) + } + + var d []uint8 + err = stmt.QueryRowContext(ctx).Scan(&d) + dl, _ := ctx.Deadline() + since := time.Since(dl) + if since > ctxTimeout { + t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since) + t.Fail() + } + expectedErr := &Error{Message: "canceling statement due to user request"} + if err == nil || err.Error() != expectedErr.Error() { + t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err()) + t.Logf("got err: [%T] %+v expected err: [%T] %+v", err, err, expectedErr, expectedErr) + t.Fail() + } +}