diff --git a/conn.go b/conn.go index 8e445f32..e050d535 100644 --- a/conn.go +++ b/conn.go @@ -18,7 +18,7 @@ import ( "path/filepath" "strconv" "strings" - "sync/atomic" + "sync" "time" "unicode" @@ -140,9 +140,10 @@ type conn struct { saveMessageType byte saveMessageBuffer []byte - // If true, this connection is bad and all public-facing functions should - // return ErrBadConn. - bad *atomic.Value + // If an error is set, this connection is bad and all public-facing + // functions should return the appropriate error by calling get() + // (ErrBadConn) or getForNext(). + err syncErr // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -166,6 +167,40 @@ type conn struct { gss GSS } +type syncErr struct { + err error + sync.Mutex +} + +// Return ErrBadConn if connection is bad. +func (e *syncErr) get() error { + e.Lock() + defer e.Unlock() + if e.err != nil { + return driver.ErrBadConn + } + return nil +} + +// Return the error set on the connection. Currently only used by rows.Next. +func (e *syncErr) getForNext() error { + e.Lock() + defer e.Unlock() + return e.err +} + +// Set error, only if it isn't set yet. +func (e *syncErr) set(err error) { + if err == nil { + panic("attempt to set nil err") + } + e.Lock() + defer e.Unlock() + if e.err == nil { + e.err = err + } +} + // Handle driver-side settings in parsed connection string. func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { @@ -306,12 +341,9 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { o[k] = v } - bad := &atomic.Value{} - bad.Store(false) cn = &conn{ opts: o, dialer: c.dialer, - bad: bad, } err = cn.handleDriverSettings(o) if err != nil { @@ -516,22 +548,9 @@ func (cn *conn) isInTransaction() bool { cn.txnStatus == txnStatusInFailedTransaction } -func (cn *conn) setBad() { - if cn.bad != nil { - cn.bad.Store(true) - } -} - -func (cn *conn) getBad() bool { - if cn.bad != nil { - return cn.bad.Load().(bool) - } - return false -} - func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -541,8 +560,8 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.getBad() { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -552,11 +571,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.setBad() + cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.setBad() + cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -570,8 +589,8 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.getBad() { - return driver.ErrBadConn + if err := cn.err.get(); err != nil { + return err } defer cn.errRecover(&err) @@ -592,12 +611,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.setBad() + cn.err.set(driver.ErrBadConn) } return err } if commandTag != "COMMIT" { - cn.setBad() + cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -606,8 +625,8 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.getBad() { - return driver.ErrBadConn + if err := cn.err.get(); err != nil { + return err } defer cn.errRecover(&err) return cn.rollback() @@ -618,7 +637,7 @@ func (cn *conn) rollback() (err error) { _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.setBad() + cn.err.set(driver.ErrBadConn) } return err } @@ -658,7 +677,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unknown response for simple query: %q", t) } } @@ -680,7 +699,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -707,7 +726,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -722,7 +741,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unknown response for simple query: %q", t) } } @@ -815,8 +834,8 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.getBad() { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -854,8 +873,8 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.getBad() { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } if cn.inCopy { return nil, errCopyInProgress @@ -888,8 +907,8 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.getBad() { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -960,7 +979,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -1330,8 +1349,8 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.getBad() { - return driver.ErrBadConn + if err := st.cn.err.get(); err != nil { + return err } defer st.cn.errRecover(&err) @@ -1344,14 +1363,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.setBad() + st.cn.err.set(driver.ErrBadConn) errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.setBad() + st.cn.err.set(driver.ErrBadConn) errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1364,8 +1383,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { } func (st *stmt) query(v []driver.Value) (r *rows, err error) { - if st.cn.getBad() { - return nil, driver.ErrBadConn + if err := st.cn.err.get(); err != nil { + return nil, err } defer st.cn.errRecover(&err) @@ -1377,8 +1396,8 @@ func (st *stmt) query(v []driver.Value) (r *rows, err error) { } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.getBad() { - return nil, driver.ErrBadConn + if err := st.cn.err.get(); err != nil { + return nil, err } defer st.cn.errRecover(&err) @@ -1464,7 +1483,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1476,7 +1495,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag @@ -1543,8 +1562,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.getBad() { - return driver.ErrBadConn + if err := conn.err.getForNext(); err != nil { + return err } defer conn.errRecover(&err) @@ -1568,7 +1587,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.setBad() + conn.err.set(driver.ErrBadConn) errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1762,7 +1781,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1782,7 +1801,7 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected Parse response %q", t) } } @@ -1807,7 +1826,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected Describe statement response %q", t) } } @@ -1825,7 +1844,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader { cn.readReadyForQuery() panic(err) default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1841,7 +1860,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected Bind response %q", t) } } @@ -1868,7 +1887,7 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected message during extended query execution: %q", t) } } @@ -1881,7 +1900,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co switch t { case 'C': if err != nil { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1895,7 +1914,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1903,7 +1922,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.setBad() + cn.err.set(driver.ErrBadConn) errorf("unknown %s response: %q", protocolState, t) } } diff --git a/conn_go18.go b/conn_go18.go index 3c83082b..63d4ca6a 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/ioutil" - "sync/atomic" "time" ) @@ -115,7 +114,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() { } // Set the connection state to bad so it does not get reused. - cn.setBad() + cn.err.set(ctx.Err()) // At this point the function level context is canceled, // so it must not be used for the additional network @@ -131,7 +130,7 @@ func (cn *conn) watchCancel(ctx context.Context) func() { return func() { select { case <-finished: - cn.setBad() + cn.err.set(ctx.Err()) cn.Close() case finished <- struct{}{}: } @@ -157,11 +156,8 @@ func (cn *conn) cancel(ctx context.Context) error { defer c.Close() { - bad := &atomic.Value{} - bad.Store(false) can := conn{ - c: c, - bad: bad, + c: c, } err = can.ssl(o) if err != nil { diff --git a/conn_test.go b/conn_test.go index f6bbfacb..b32f983a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,7 +10,6 @@ import ( "os" "reflect" "strings" - "sync/atomic" "testing" "time" ) @@ -696,9 +695,7 @@ func TestErrorDuringStartupClosesConn(t *testing.T) { func TestBadConn(t *testing.T) { var err error - bad := &atomic.Value{} - bad.Store(false) - cn := conn{bad: bad} + cn := conn{} func() { defer cn.errRecover(&err) panic(io.EOF) @@ -706,13 +703,11 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.getBad() { - t.Fatalf("expected cn.bad") + if err := cn.err.get(); err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got %#v", err) } - badd := &atomic.Value{} - badd.Store(false) - cn = conn{bad: badd} + cn = conn{} func() { defer cn.errRecover(&err) e := &Error{Severity: Efatal} @@ -721,8 +716,8 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.getBad() { - t.Fatalf("expected cn.bad") + if err := cn.err.get(); err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got %#v", err) } } diff --git a/copy.go b/copy.go index bb3cbd7b..c072bc3b 100644 --- a/copy.go +++ b/copy.go @@ -49,12 +49,14 @@ type copyin struct { buffer []byte rowData chan []byte done chan bool - driver.Result closed bool - sync.Mutex // guards err - err error + mu struct { + sync.Mutex + err error + driver.Result + } } const ciBufferSize = 64 * 1024 @@ -98,13 +100,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unknown response for copy query: %q", t) } } @@ -123,7 +125,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unknown response for CopyFail: %q", t) } } @@ -144,7 +146,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.setBad() + ci.setBad(driver.ErrBadConn) ci.setError(err) ci.done <- true return @@ -166,7 +168,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.setBad() + ci.setBad(driver.ErrBadConn) ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -174,46 +176,41 @@ func (ci *copyin) resploop() { } } -func (ci *copyin) setBad() { - ci.Lock() - ci.cn.setBad() - ci.Unlock() +func (ci *copyin) setBad(err error) { + ci.cn.err.set(err) } -func (ci *copyin) isBad() bool { - ci.Lock() - b := ci.cn.getBad() - ci.Unlock() - return b +func (ci *copyin) getBad() error { + return ci.cn.err.get() } -func (ci *copyin) isErrorSet() bool { - ci.Lock() - isSet := (ci.err != nil) - ci.Unlock() - return isSet +func (ci *copyin) err() error { + ci.mu.Lock() + err := ci.mu.err + ci.mu.Unlock() + return err } // setError() sets ci.err if one has not been set already. Caller must not be // holding ci.Mutex. func (ci *copyin) setError(err error) { - ci.Lock() - if ci.err == nil { - ci.err = err + ci.mu.Lock() + if ci.mu.err == nil { + ci.mu.err = err } - ci.Unlock() + ci.mu.Unlock() } func (ci *copyin) setResult(result driver.Result) { - ci.Lock() - ci.Result = result - ci.Unlock() + ci.mu.Lock() + ci.mu.Result = result + ci.mu.Unlock() } func (ci *copyin) getResult() driver.Result { - ci.Lock() - result := ci.Result - ci.Unlock() + ci.mu.Lock() + result := ci.mu.Result + ci.mu.Unlock() if result == nil { return driver.RowsAffected(0) } @@ -240,13 +237,13 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.isBad() { - return nil, driver.ErrBadConn + if err := ci.getBad(); err != nil { + return nil, err } defer ci.cn.errRecover(&err) - if ci.isErrorSet() { - return nil, ci.err + if err := ci.err(); err != nil { + return nil, err } if len(v) == 0 { @@ -282,8 +279,8 @@ func (ci *copyin) Close() (err error) { } ci.closed = true - if ci.isBad() { - return driver.ErrBadConn + if err := ci.getBad(); err != nil { + return err } defer ci.cn.errRecover(&err) @@ -299,8 +296,7 @@ func (ci *copyin) Close() (err error) { <-ci.done ci.cn.inCopy = false - if ci.isErrorSet() { - err = ci.err + if err := ci.err(); err != nil { return err } return nil diff --git a/error.go b/error.go index b0f53755..5cfe9c6e 100644 --- a/error.go +++ b/error.go @@ -484,7 +484,7 @@ func (cn *conn) errRecover(err *error) { case nil: // Do nothing case runtime.Error: - cn.setBad() + cn.err.set(driver.ErrBadConn) panic(v) case *Error: if v.Fatal() { @@ -493,10 +493,10 @@ func (cn *conn) errRecover(err *error) { *err = v } case *net.OpError: - cn.setBad() + cn.err.set(driver.ErrBadConn) *err = v case *safeRetryError: - cn.setBad() + cn.err.set(driver.ErrBadConn) *err = driver.ErrBadConn case error: if v == io.EOF || v.Error() == "remote error: handshake failure" { @@ -506,13 +506,13 @@ func (cn *conn) errRecover(err *error) { } default: - cn.setBad() + cn.err.set(driver.ErrBadConn) panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - cn.setBad() + cn.err.set(driver.ErrBadConn) } } diff --git a/go18_test.go b/go18_test.go index 95c08cd7..27501e74 100644 --- a/go18_test.go +++ b/go18_test.go @@ -143,7 +143,7 @@ func TestContextCancelQuery(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := rows.Close(); err != nil && err != driver.ErrBadConn { + } else if err := rows.Close(); err != nil && err != driver.ErrBadConn && err != context.Canceled { t.Fatal(err) } }() @@ -242,7 +242,7 @@ func TestContextCancelBegin(t *testing.T) { t.Fatal(err) } else if err := tx.Rollback(); err != nil && err.Error() != "pq: canceling statement due to user request" && - err != sql.ErrTxDone && err != driver.ErrBadConn { + err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled { t.Fatal(err) } }() diff --git a/issues_test.go b/issues_test.go index 55d3f1ec..4d24c9dd 100644 --- a/issues_test.go +++ b/issues_test.go @@ -58,3 +58,22 @@ func TestIssue1046(t *testing.T) { t.Fail() } } + +func TestIssue1062(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // Ensure that cancelling a QueryRowContext does not result in an ErrBadConn. + + for i := 0; i < 100; i++ { + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + row := db.QueryRowContext(ctx, "select 1") + + var v int + err := row.Scan(&v) + if err != nil && err != context.Canceled && err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1) + } + } +}