Skip to content

Commit

Permalink
Merge pull request dtm-labs#71 from yedf/alpha
Browse files Browse the repository at this point in the history
barrier interface change to sql.Tx
  • Loading branch information
yedf2 authored Dec 4, 2021
2 parents 4b861e7 + 5fc7e5a commit c50a972
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 45 deletions.
2 changes: 1 addition & 1 deletion bench/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func qsAdjustBalance(uid int, amount int, c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, nil
}
tb := dtmimp.TransBaseFromQuery(c.Request.URL.Query())
f := func(tx dtmcli.DB) error {
f := func(tx *sql.Tx) error {
for i := 0; i < sqls; i++ {
_, err := dtmimp.DBExec(tx, "insert into dtm_busi.user_account_log(user_id, delta, gid, branch_id, op, reason) values(?,?,?,?,?,?)",
uid, amount, tb.Gid, c.Query("branch_id"), tb.TransType, fmt.Sprintf("inserted by dtm transaction %s %s", tb.Gid, c.Query("branch_id")))
Expand Down
16 changes: 13 additions & 3 deletions dtmcli/barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
package dtmcli

import (
"database/sql"
"fmt"
"net/url"

"github.com/yedf/dtm/dtmcli/dtmimp"
)

// BarrierBusiFunc type for busi func
type BarrierBusiFunc func(db DB) error
type BarrierBusiFunc func(tx *sql.Tx) error

// BranchBarrier every branch info
type BranchBarrier struct {
Expand Down Expand Up @@ -48,7 +49,7 @@ func BarrierFrom(transType, gid, branchID, op string) (*BranchBarrier, error) {
return ti, nil
}

func insertBarrier(tx Tx, transType string, gid string, branchID string, op string, barrierID string, reason string) (int64, error) {
func insertBarrier(tx DB, transType string, gid string, branchID string, op string, barrierID string, reason string) (int64, error) {
if op == "" {
return 0, nil
}
Expand All @@ -59,7 +60,7 @@ func insertBarrier(tx Tx, transType string, gid string, branchID string, op stri
// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
// tx: 本地数据库的事务对象,允许子事务屏障进行事务操作
// busiCall: 业务函数,仅在必要时被调用
func (bb *BranchBarrier) Call(tx Tx, busiCall BarrierBusiFunc) (rerr error) {
func (bb *BranchBarrier) Call(tx *sql.Tx, busiCall BarrierBusiFunc) (rerr error) {
bb.BarrierID = bb.BarrierID + 1
bid := fmt.Sprintf("%02d", bb.BarrierID)
defer func() {
Expand Down Expand Up @@ -89,3 +90,12 @@ func (bb *BranchBarrier) Call(tx Tx, busiCall BarrierBusiFunc) (rerr error) {
rerr = busiCall(tx)
return
}

// CallWithDB the same as Call, but with *sql.DB
func (bb *BranchBarrier) CallWithDB(db *sql.DB, busiCall BarrierBusiFunc) error {
tx, err := db.Begin()
if err != nil {
return err
}
return bb.Call(tx, busiCall)
}
7 changes: 0 additions & 7 deletions dtmcli/dtmimp/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,3 @@ type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error)
QueryRow(query string, args ...interface{}) *sql.Row
}

// Tx interface of dtmcli tx
type Tx interface {
Rollback() error
Commit() error
DB
}
3 changes: 0 additions & 3 deletions dtmcli/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ func MustGenGid(server string) string {
// DB interface
type DB = dtmimp.DB

// Tx interface
type Tx = dtmimp.Tx

// TransOptions transaction option
type TransOptions = dtmimp.TransOptions

Expand Down
15 changes: 8 additions & 7 deletions examples/grpc_saga_barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples

import (
"context"
"database/sql"

"github.com/yedf/dtm/dtmcli"
"github.com/yedf/dtm/dtmcli/dtmimp"
Expand Down Expand Up @@ -41,28 +42,28 @@ func sagaGrpcBarrierAdjustBalance(db dtmcli.DB, uid int, amount int64, result st

func (s *busiServer) TransInBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx dtmcli.DB) error {
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(tx, 2, in.Amount, in.TransInResult)
})
}

func (s *busiServer) TransOutBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 1, -in.Amount, in.TransOutResult)
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(tx, 1, -in.Amount, in.TransOutResult)
})
}

func (s *busiServer) TransInRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 2, -in.Amount, "")
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(tx, 2, -in.Amount, "")
})
}

func (s *busiServer) TransOutRevertBSaga(ctx context.Context, in *BusiReq) (*emptypb.Empty, error) {
barrier := MustBarrierFromGrpc(ctx)
return &emptypb.Empty{}, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaGrpcBarrierAdjustBalance(db, 1, in.Amount, "")
return &emptypb.Empty{}, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaGrpcBarrierAdjustBalance(tx, 1, in.Amount, "")
})
}
18 changes: 10 additions & 8 deletions examples/http_saga_barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package examples

import (
"database/sql"

"github.com/gin-gonic/gin"
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
Expand Down Expand Up @@ -45,15 +47,15 @@ func sagaBarrierTransIn(c *gin.Context) (interface{}, error) {
return req.TransInResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 1, req.Amount)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaBarrierAdjustBalance(tx, 1, req.Amount)
})
}

func sagaBarrierTransInCompensate(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 1, -reqFrom(c).Amount)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaBarrierAdjustBalance(tx, 1, -reqFrom(c).Amount)
})
}

Expand All @@ -63,14 +65,14 @@ func sagaBarrierTransOut(c *gin.Context) (interface{}, error) {
return req.TransOutResult, nil
}
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 2, -req.Amount)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaBarrierAdjustBalance(tx, 2, -req.Amount)
})
}

func sagaBarrierTransOutCompensate(c *gin.Context) (interface{}, error) {
barrier := MustBarrierFromGin(c)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(db dtmcli.DB) error {
return sagaBarrierAdjustBalance(db, 2, reqFrom(c).Amount)
return dtmcli.MapSuccess, barrier.Call(txGet(), func(tx *sql.Tx) error {
return sagaBarrierAdjustBalance(tx, 2, reqFrom(c).Amount)
})
}
2 changes: 1 addition & 1 deletion examples/http_saga_gorm_barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func sagaGormBarrierTransOut(c *gin.Context) (interface{}, error) {
req := reqFrom(c)
barrier := MustBarrierFromGin(c)
tx := dbGet().DB.Begin()
return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(db dtmcli.DB) error {
return dtmcli.MapSuccess, barrier.Call(tx.Statement.ConnPool.(*sql.Tx), func(tx1 *sql.Tx) error {
return tx.Exec("update dtm_busi.user_account set balance = balance + ? where user_id = ?", -req.Amount, 2).Error
})
}
25 changes: 13 additions & 12 deletions examples/http_tcc_barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package examples

import (
"database/sql"
"fmt"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -68,20 +69,20 @@ func tccBarrierTransInTry(c *gin.Context) (interface{}, error) {
if req.TransInResult != "" {
return req.TransInResult, nil
}
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transInUID, req.Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustTrading(tx, transInUID, req.Amount)
})
}

func tccBarrierTransInConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustBalance(db, transInUID, reqFrom(c).Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustBalance(tx, transInUID, reqFrom(c).Amount)
})
}

func tccBarrierTransInCancel(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transInUID, -reqFrom(c).Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustTrading(tx, transInUID, -reqFrom(c).Amount)
})
}

Expand All @@ -90,20 +91,20 @@ func tccBarrierTransOutTry(c *gin.Context) (interface{}, error) {
if req.TransOutResult != "" {
return req.TransOutResult, nil
}
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transOutUID, -req.Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustTrading(tx, transOutUID, -req.Amount)
})
}

func tccBarrierTransOutConfirm(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustBalance(db, transOutUID, -reqFrom(c).Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustBalance(tx, transOutUID, -reqFrom(c).Amount)
})
}

// TccBarrierTransOutCancel will be use in test
func TccBarrierTransOutCancel(c *gin.Context) (interface{}, error) {
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(db dtmcli.DB) error {
return adjustTrading(db, transOutUID, reqFrom(c).Amount)
return dtmcli.MapSuccess, MustBarrierFromGin(c).Call(txGet(), func(tx *sql.Tx) error {
return adjustTrading(tx, transOutUID, reqFrom(c).Amount)
})
}
5 changes: 3 additions & 2 deletions test/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package test

import (
"database/sql"
"fmt"
"testing"

Expand Down Expand Up @@ -38,7 +39,7 @@ func TestBaseSqlDB(t *testing.T) {
db.Must().Exec("insert into dtm_barrier.barrier(trans_type, gid, branch_id, op, reason) values('saga', 'gid1', 'branch_id1', 'action', 'saga')")
tx, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx, func(db dtmcli.DB) error {
err = barrier.Call(tx, func(tx *sql.Tx) error {
dtmimp.Logf("rollback gid2")
return fmt.Errorf("gid2 error")
})
Expand All @@ -50,7 +51,7 @@ func TestBaseSqlDB(t *testing.T) {
barrier.BarrierID = 0
tx2, err := db.ToSQLDB().Begin()
asserts.Nil(err)
err = barrier.Call(tx2, func(db dtmcli.DB) error {
err = barrier.Call(tx2, func(tx *sql.Tx) error {
dtmimp.Logf("submit gid2")
return nil
})
Expand Down
2 changes: 1 addition & 1 deletion test/tcc_barrier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func TestTccBarrierPanic(t *testing.T) {
func() {
defer dtmimp.P2E(&err)
tx, _ := dbGet().ToSQLDB().BeginTx(context.Background(), &sql.TxOptions{})
bb.Call(tx, func(db dtmcli.DB) error {
bb.Call(tx, func(tx *sql.Tx) error {
panic(fmt.Errorf("an error"))
})
}()
Expand Down

0 comments on commit c50a972

Please sign in to comment.