diff --git a/orm/db.go b/orm/db.go index 25666eb..79c5f95 100644 --- a/orm/db.go +++ b/orm/db.go @@ -178,10 +178,15 @@ type DBTx struct { err error rowsAffected int64 wrappers []database.Wrapper + afterCommit func(err error) } func (store *DBStore) BeginTx() (*DBTx, error) { - tx, err := store.Begin() + return store.BeginTxContext(context.Background()) +} + +func (store *DBStore) BeginTxContext(ctx context.Context) (*DBTx, error) { + tx, err := store.DB.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -198,10 +203,18 @@ func (tx *DBTx) Close() error { if tx.err != nil { return tx.tx.Rollback() } - return tx.tx.Commit() + err := tx.tx.Commit() + if tx.afterCommit != nil { + tx.afterCommit(err) + } + return err } func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { + return tx.QueryContext(context.Background(), sql, args...) +} + +func (tx *DBTx) queryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -214,7 +227,7 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { if tx.debug { log.Println("DEBUG: ", sql, args) } - result, err := tx.tx.Query(sql, args...) + result, err := tx.tx.QueryContext(ctx, sql, args...) if err != nil { tx.err = err } @@ -222,6 +235,10 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { } func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) { + return tx.ExecContext(context.Background(), sql, args...) +} + +func (tx *DBTx) execContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -244,7 +261,7 @@ func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) { func (tx *DBTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { fn := func(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - return tx.tx.QueryContext(ctx, query, args...) + return tx.queryContext(ctx, query, args...) } for _, wp := range tx.wrappers { fn = wp.WrapQueryContext(fn, query, args...) @@ -255,7 +272,7 @@ func (tx *DBTx) QueryContext(ctx context.Context, query string, func (tx *DBTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { fn := func(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - return tx.tx.ExecContext(ctx, query, args...) + return tx.execContext(ctx, query, args...) } for _, wp := range tx.wrappers { fn = wp.WrapExecContext(fn, query, args...) @@ -267,6 +284,14 @@ func (tx *DBTx) SetError(err error) { tx.err = err } +func (tx *DBTx) AfterCommit(afterCommit func(err error)) { + tx.afterCommit = afterCommit +} + +func (tx *DBTx) GetStdTx() *sql.Tx { + return tx.tx +} + func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) { tx, err := db.BeginTx() if err != nil { @@ -289,6 +314,28 @@ func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) { return err } +func TransactFuncContext(ctx context.Context, db *DBStore, txFunc func(ctx context.Context, tx *DBTx) error) (err error) { + tx, err := db.BeginTxContext(ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + tx.SetError(fmt.Errorf("panic: %v", p)) + tx.Close() + panic(p) + } else if err != nil { + tx.SetError(err) + tx.Close() + } else { + err = tx.Close() + } + }() + + err = txFunc(ctx, tx) + return err +} + type Transactor interface { Transact(tx *DBTx) error } @@ -296,3 +343,22 @@ type Transactor interface { func Transact(db *DBStore, t Transactor) error { return TransactFunc(db, t.Transact) } + +type TransactorWithContext interface { + TransactContext(ctx context.Context, tx *DBTx) error +} + +func TransactContext(ctx context.Context, db *DBStore, t TransactorWithContext) error { + return TransactFuncContext(ctx, db, t.TransactContext) +} + +func BeginTx(ctx context.Context, db *sql.DB) (*DBTx, error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + return &DBTx{ + tx: tx, + }, nil +}