Skip to content

Commit

Permalink
Create ctx from testing context, instead of using context.Background
Browse files Browse the repository at this point in the history
  • Loading branch information
reductionista committed Nov 1, 2024
1 parent c540333 commit 8e479fe
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pkg/pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func NewSqlxDB(t testing.TB, dbURL string) *sqlx.DB {
tests.SkipShortDB(t)
err := RegisterTxDb(dbURL)
err := RegisterTxDb(tests.Context(t), dbURL)
if err != nil {
t.Errorf("failed to register txdb dialect: %s", err.Error())
return nil
Expand Down
68 changes: 48 additions & 20 deletions pkg/pg/txdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (

"github.com/jmoiron/sqlx"
"go.uber.org/multierr"

"github.com/smartcontractkit/chainlink-common/pkg/utils"
)

// txdb is a simplified version of https://github.com/DATA-DOG/go-txdb
Expand All @@ -32,7 +34,7 @@ import (
// store to use the raw DialectPostgres dialect and setup a one-use database.
// See heavyweight.FullTestDB() as a convenience function to help you do this,
// but please use sparingly because as it's name implies, it is expensive.
func RegisterTxDb(dbURL string) error {
func RegisterTxDb(ctx context.Context, dbURL string) error {
drivers := sql.Drivers()
for _, driver := range drivers {
if driver == string(TransactionWrappedPostgres) {
Expand All @@ -58,8 +60,15 @@ func RegisterTxDb(dbURL string) error {
if !strings.HasSuffix(parsed.Path, "_test") {
return fmt.Errorf("cannot run tests against database named `%s`. Note that the test database MUST end in `_test` to differentiate from a possible production DB. HINT: Try postgresql://postgres@localhost:5432/chainlink_test?sslmode=disable", parsed.Path[1:])
}
abort := make(chan struct{})
go func() {
<-ctx.Done()
abort <- struct{}{} // abort all queries when context is cancelled
}()

name := string(TransactionWrappedPostgres)
sql.Register(name, &txDriver{
abort: abort,
dbURL: dbURL,
conns: make(map[string]*conn),
})
Expand All @@ -76,6 +85,7 @@ var _ driver.SessionResetter = &conn{}
// When `Close` is called, transaction is rolled back.
type txDriver struct {
sync.Mutex
abort <-chan struct{}
db *sql.DB
conns map[string]*conn

Expand All @@ -99,7 +109,7 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c = &conn{tx: tx, opened: 1, dsn: dsn}
c = &conn{abort: d.abort, tx: tx, opened: 1, dsn: dsn}
c.removeSelf = func() error {
return d.deleteConn(c)
}
Expand Down Expand Up @@ -130,6 +140,7 @@ func (d *txDriver) deleteConn(c *conn) error {

type conn struct {
sync.Mutex
abort <-chan struct{}
dsn string
tx *sql.Tx // tx may be shared by many conns, definitive one lives in the map keyed by DSN on the txDriver. Do not modify from conn
closed bool
Expand All @@ -156,26 +167,32 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err

// Prepare returns a prepared statement, bound to this connection.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
ctx, cancel := utils.ContextFromChan(c.abort)
defer cancel()
return c.PrepareContext(ctx, query)
}

// Implement the "ConnPrepareContext" interface
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
c.Lock()
defer c.Unlock()
if c.closed {
panic("conn is closed")
}

// TODO: Fix context handling
// FIXME: It is not safe to give the passed in context to the tx directly
// It is not safe to give the passed in context to the tx directly
// because the tx is shared by many conns and cancelling the context will
// destroy the tx which can affect other conns
st, err := c.tx.PrepareContext(context.Background(), query)
// destroy the tx which can affect other conns. Instead, we pass the context
// passed to NewSqlxDb when the database was set up so the operation can at
// least be aborted immediately if the whole test is interrupted.
ctx, cancel := utils.ContextFromChan(c.abort)
defer cancel()

st, err := c.tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &stmt{st, c}, nil
return &stmt{c.abort, st, c}, nil
}

// IsValid is called prior to placing the connection into the
Expand Down Expand Up @@ -212,8 +229,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
panic("conn is closed")
}

// TODO: Fix context handling
rs, err := c.tx.QueryContext(context.Background(), query, mapNamedArgs(args)...)
ctx, cancel := utils.ContextFromChan(c.abort)
defer cancel()

rs, err := c.tx.QueryContext(ctx, query, mapNamedArgs(args)...)
if err != nil {
return nil, err
}
Expand All @@ -229,8 +248,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if c.closed {
return nil, fmt.Errorf("conn is closed")
}
// TODO: Fix context handling
return c.tx.ExecContext(context.Background(), query, mapNamedArgs(args)...)
ctx, cancel := utils.ContextFromChan(c.abort)
defer cancel()

return c.tx.ExecContext(ctx, query, mapNamedArgs(args)...)
}

// tryOpen attempts to increment the open count, but returns false if closed.
Expand Down Expand Up @@ -305,8 +326,9 @@ func (tx tx) Rollback() error {
}

type stmt struct {
st *sql.Stmt
conn *conn
abort <-chan struct{}
st *sql.Stmt
conn *conn
}

func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
Expand All @@ -325,8 +347,11 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
if s.conn.closed {
panic("conn is closed")
}
// TODO: Fix context handling
return s.st.ExecContext(context.Background(), mapNamedArgs(args)...)

ctx, cancel := utils.ContextFromChan(s.abort)
defer cancel()

return s.st.ExecContext(ctx, mapNamedArgs(args)...)
}

func mapArgs(args []driver.Value) (res []interface{}) {
Expand Down Expand Up @@ -358,14 +383,17 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
}

// Implement the "StmtQueryContext" interface
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
func (s *stmt) QueryContext(_ context.Context, args []driver.NamedValue) (driver.Rows, error) {
s.conn.Lock()
defer s.conn.Unlock()
if s.conn.closed {
panic("conn is closed")
}
// TODO: Fix context handling
rows, err := s.st.QueryContext(context.Background(), mapNamedArgs(args)...)

ctx, cancel := utils.ContextFromChan(s.abort)
defer cancel()

rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...)
if err != nil {
return nil, err
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/pg/txdb_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pg

import (
"database/sql"
"os"
"testing"
"time"
Expand Down Expand Up @@ -52,4 +53,11 @@ func TestTxDBDriver(t *testing.T) {
time.Sleep(time.Second * 10)
ensureValuesPresent(t, db)
})

t.Run("Make sure calling sql.Register() can be called twice", func(t *testing.T) {
require.NoError(t, RegisterTxDb(tests.Context(t), "foo"))
require.NoError(t, RegisterTxDb(tests.Context(t), "bar"))
drivers := sql.Drivers()
assert.Contains(t, drivers, "txdb")
})
}

0 comments on commit 8e479fe

Please sign in to comment.