From 97eab81b35c458dfda6f0dbad91293c6c5d93e95 Mon Sep 17 00:00:00 2001 From: zhaokongsheng Date: Sat, 1 Dec 2018 15:46:32 +0800 Subject: [PATCH 1/5] Add context --- orm/db.go | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/orm/db.go b/orm/db.go index 25666eb..c97052d 100644 --- a/orm/db.go +++ b/orm/db.go @@ -181,7 +181,11 @@ type DBTx struct { } func (store *DBStore) BeginTx() (*DBTx, error) { - tx, err := store.Begin() + return store.BeginTxContext(context.Background()) +} + +func (store *DBStore) BeginTxContext(ctx context.Context) (*DBTx, error) { + tx, err := store.DB.BeginTx(ctx, nil) if err != nil { return nil, err } @@ -202,6 +206,10 @@ func (tx *DBTx) Close() error { } func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { + return tx.QueryContext(context.Background(), sql, args...) +} + +func (tx *DBTx) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -214,7 +222,7 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { if tx.debug { log.Println("DEBUG: ", sql, args) } - result, err := tx.tx.Query(sql, args...) + result, err := tx.tx.QueryContext(ctx, sql, args...) if err != nil { tx.err = err } @@ -222,6 +230,10 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { } func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) { + return tx.ExecContext(context.Background(), sql, args...) +} + +func (tx *DBTx) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -289,6 +301,28 @@ func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) { return err } +func TransactFuncContext(ctx context.Context, db *DBStore, txFunc func(ctx context.Context, tx *DBTx) error) (err error) { + tx, err := db.BeginTxContext(ctx) + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + tx.SetError(fmt.Errorf("panic: %v", p)) + tx.Close() + panic(p) + } else if err != nil { + tx.SetError(err) + tx.Close() + } else { + err = tx.Close() + } + }() + + err = txFunc(ctx, tx) + return err +} + type Transactor interface { Transact(tx *DBTx) error } @@ -296,3 +330,11 @@ type Transactor interface { func Transact(db *DBStore, t Transactor) error { return TransactFunc(db, t.Transact) } + +type TransactorWithContext interface { + TransactContext(ctx context.Context, tx *DBTx) error +} + +func TransactContext(ctx context.Context, db *DBStore, t TransactorWithContext) error { + return TransactFuncContext(ctx, db, t.TransactContext) +} From e4d8b88b1ccc524f8a783f7ffd56c4a1c1e55d20 Mon Sep 17 00:00:00 2001 From: zhaokongsheng Date: Sat, 1 Dec 2018 16:48:40 +0800 Subject: [PATCH 2/5] Add after commit feature --- orm/db.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/orm/db.go b/orm/db.go index c97052d..c0165a6 100644 --- a/orm/db.go +++ b/orm/db.go @@ -178,6 +178,7 @@ type DBTx struct { err error rowsAffected int64 wrappers []database.Wrapper + afterCommit func(err error) } func (store *DBStore) BeginTx() (*DBTx, error) { @@ -202,7 +203,9 @@ func (tx *DBTx) Close() error { if tx.err != nil { return tx.tx.Rollback() } - return tx.tx.Commit() + err := tx.tx.Commit() + tx.afterCommit(err) + return err } func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { @@ -279,6 +282,10 @@ func (tx *DBTx) SetError(err error) { tx.err = err } +func (tx *DBTx) AfterCommit(afterCommit func(err error)) { + tx.afterCommit = afterCommit +} + func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) { tx, err := db.BeginTx() if err != nil { @@ -338,3 +345,14 @@ type TransactorWithContext interface { func TransactContext(ctx context.Context, db *DBStore, t TransactorWithContext) error { return TransactFuncContext(ctx, db, t.TransactContext) } + +func BeginTx(ctx context.Context, db *sql.DB) (*DBTx, error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + return &DBTx{ + tx: tx, + }, nil +} From aa2bf619d77571acb70c713810841b31fc52c477 Mon Sep 17 00:00:00 2001 From: zhaokongsheng Date: Tue, 25 Dec 2018 17:17:53 +0800 Subject: [PATCH 3/5] Fix nil after commit --- orm/db.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orm/db.go b/orm/db.go index c0165a6..f4df3b2 100644 --- a/orm/db.go +++ b/orm/db.go @@ -204,7 +204,9 @@ func (tx *DBTx) Close() error { return tx.tx.Rollback() } err := tx.tx.Commit() - tx.afterCommit(err) + if tx.afterCommit != nil { + tx.afterCommit(err) + } return err } From 09170ea5ffdb87eeecfe06988559811f17f128b0 Mon Sep 17 00:00:00 2001 From: zhaokongsheng Date: Fri, 4 Jan 2019 18:58:22 +0800 Subject: [PATCH 4/5] Add GetStdTx --- orm/db.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/orm/db.go b/orm/db.go index f4df3b2..41aaca4 100644 --- a/orm/db.go +++ b/orm/db.go @@ -288,6 +288,10 @@ func (tx *DBTx) AfterCommit(afterCommit func(err error)) { tx.afterCommit = afterCommit } +func (tx *DBTx) GetStdTx() *sql.Tx { + return tx.tx +} + func TransactFunc(db *DBStore, txFunc func(*DBTx) error) (err error) { tx, err := db.BeginTx() if err != nil { From 8c3161635f40707cabe29aad666aedfbcf9747bb Mon Sep 17 00:00:00 2001 From: zhaokongsheng Date: Fri, 15 Mar 2019 15:34:41 +0800 Subject: [PATCH 5/5] Fix duplicate def --- orm/db.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/orm/db.go b/orm/db.go index 41aaca4..79c5f95 100644 --- a/orm/db.go +++ b/orm/db.go @@ -214,7 +214,7 @@ func (tx *DBTx) Query(sql string, args ...interface{}) (*sql.Rows, error) { return tx.QueryContext(context.Background(), sql, args...) } -func (tx *DBTx) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { +func (tx *DBTx) queryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -238,7 +238,7 @@ func (tx *DBTx) Exec(sql string, args ...interface{}) (sql.Result, error) { return tx.ExecContext(context.Background(), sql, args...) } -func (tx *DBTx) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { +func (tx *DBTx) execContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { t1 := time.Now() if tx.slowlog > 0 { defer func(t time.Time) { @@ -261,7 +261,7 @@ func (tx *DBTx) ExecContext(ctx context.Context, sql string, args ...interface{} func (tx *DBTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { fn := func(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - return tx.tx.QueryContext(ctx, query, args...) + return tx.queryContext(ctx, query, args...) } for _, wp := range tx.wrappers { fn = wp.WrapQueryContext(fn, query, args...) @@ -272,7 +272,7 @@ func (tx *DBTx) QueryContext(ctx context.Context, query string, func (tx *DBTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { fn := func(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - return tx.tx.ExecContext(ctx, query, args...) + return tx.execContext(ctx, query, args...) } for _, wp := range tx.wrappers { fn = wp.WrapExecContext(fn, query, args...)