Skip to content

Commit

Permalink
postgres supported
Browse files Browse the repository at this point in the history
  • Loading branch information
yedf2 committed Oct 9, 2021
1 parent b2dda03 commit 9518a93
Show file tree
Hide file tree
Showing 19 changed files with 359 additions and 36 deletions.
11 changes: 7 additions & 4 deletions common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
"sync"
"time"

_ "github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql" // register mysql driver
_ "github.com/lib/pq" // register postgres driver
"gopkg.in/yaml.v2"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"

"github.com/yedf/dtm/dtmcli"
Expand All @@ -26,10 +28,11 @@ type ModelBase struct {
}

func getGormDialetor(driver string, dsn string) gorm.Dialector {
if driver == "mysql" {
return mysql.Open(dsn)
if driver == dtmcli.DriverPostgres {
return postgres.Open(dsn)
}
panic(fmt.Errorf("unkown driver: %s", driver))
dtmcli.PanicIf(driver != dtmcli.DriverMysql, fmt.Errorf("unkown driver: %s", driver))
return mysql.Open(dsn)
}

var dbs sync.Map
Expand Down
2 changes: 2 additions & 0 deletions common/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func TestDbAlone(t *testing.T) {
assert.Nil(t, err)
_, err = dtmcli.DBExec(db, "select 1")
assert.Equal(t, nil, err)
_, err = dtmcli.DBExec(db, "")
assert.Equal(t, nil, err)
db.Close()
_, err = dtmcli.DBExec(db, "select 1")
assert.NotEqual(t, nil, err)
Expand Down
6 changes: 5 additions & 1 deletion dtmcli/barrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ func insertBarrier(tx Tx, transType string, gid string, branchID string, branchT
if branchType == "" {
return 0, nil
}
return DBExec(tx, "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)", transType, gid, branchID, branchType, barrierID, reason)
sql := map[string]string{
"mysql": "insert ignore into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?)",
"postgres": "insert into dtm_barrier.barrier(trans_type, gid, branch_id, branch_type, barrier_id, reason) values(?,?,?,?,?,?) on conflict ON CONSTRAINT uniq_barrier do nothing",
}[DBDriver]
return DBExec(tx, sql, transType, gid, branchID, branchType, barrierID, reason)
}

// Call 子事务屏障,详细介绍见 https://zhuanlan.zhihu.com/p/388444465
Expand Down
20 changes: 20 additions & 0 deletions dtmcli/barrier.postgres.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
create schema if not exists dtm_barrier;

drop table if exists dtm_barrier.barrier;

CREATE SEQUENCE if not EXISTS dtm_barrier.barrier_seq;

create table if not exists dtm_barrier.barrier(
id int NOT NULL DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'),
trans_type varchar(45) default '' ,
gid varchar(128) default'',
branch_id varchar(128) default '',
branch_type varchar(45) default '',
barrier_id varchar(45) default '',
reason varchar(45) default '',
create_time timestamp(0) DEFAULT NULL,
update_time timestamp(0) DEFAULT NULL,
PRIMARY KEY(id),
CONSTRAINT uniq_barrier unique(gid, branch_id, branch_type, barrier_id)
);

33 changes: 26 additions & 7 deletions dtmcli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -135,6 +136,7 @@ var FatalExitFunc = func() { os.Exit(1) }

// LogFatalf 采用红色打印错误类信息, 并退出
func LogFatalf(fmt string, args ...interface{}) {
fmt += "\n" + string(debug.Stack())
Logf("\x1b[31m\n"+fmt+"\x1b[0m\n", args...)
FatalExitFunc()
}
Expand Down Expand Up @@ -210,7 +212,10 @@ func StandaloneDB(conf map[string]string) (*sql.DB, error) {

// DBExec use raw db to exec
func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr error) {
sql = makeSqlCompatible(sql)
if sql == "" {
return 0, nil
}
sql = makeSQLCompatible(sql)
r, rerr := db.Exec(sql, values...)
if rerr == nil {
affected, rerr = r.RowsAffected()
Expand All @@ -223,7 +228,7 @@ func DBExec(db DB, sql string, values ...interface{}) (affected int64, rerr erro

// DBQueryRow use raw tx to query row
func DBQueryRow(db DB, query string, args ...interface{}) *sql.Row {
query = makeSqlCompatible(query)
query = makeSQLCompatible(query)
Logf("querying: "+query, args...)
return db.QueryRow(query, args...)
}
Expand Down Expand Up @@ -271,10 +276,8 @@ func CheckResult(res interface{}, err error) error {
return err
}

func makeSqlCompatible(sql string) string {
if DBDriver == DriverMysql {
return sql
} else if DBDriver == DriverPostgres {
func makeSQLCompatible(sql string) string {
if DBDriver == DriverPostgres {
pos := 1
parts := []string{}
b := 0
Expand All @@ -286,7 +289,23 @@ func makeSqlCompatible(sql string) string {
pos++
}
}
parts = append(parts, sql[b:])
return strings.Join(parts, "")
}
panic(fmt.Sprintf("unknown driver %s", DBDriver))
PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver))
return sql
}

func getXaSQL(action string, xid string) string {
if DBDriver == DriverPostgres {
return map[string]string{
"end": "",
"start": "begin",
"prepare": fmt.Sprintf("prepare transaction '%s'", xid),
"commit": fmt.Sprintf("commit prepared '%s'", xid),
"rollback": fmt.Sprintf("rollback prepared '%s'", xid),
}[action]
}
PanicIf(DBDriver != DriverMysql, fmt.Errorf("unkown db driver: %s", DBDriver))
return fmt.Sprintf("xa %s '%s'", action, xid)
}
8 changes: 5 additions & 3 deletions dtmcli/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ func TestFatal(t *testing.T) {
assert.Error(t, err, fmt.Errorf("fatal"))
}

func TestMakeSqlCompatible(t *testing.T) {
func TestCompatible(t *testing.T) {
old := DBDriver
DBDriver = DriverMysql
assert.Equal(t, "? ?", makeSqlCompatible("? ?"))
assert.Equal(t, "? ?", makeSQLCompatible("? ?"))
assert.Equal(t, "xa start 'xa1'", getXaSQL("start", "xa1"))
DBDriver = DriverPostgres
assert.Equal(t, "$1 $2", makeSqlCompatible("? ?"))
assert.Equal(t, "$1 $2", makeSQLCompatible("? ?"))
assert.Equal(t, "begin", getXaSQL("start", "xa1"))
DBDriver = old
}
12 changes: 6 additions & 6 deletions dtmcli/xa_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dtmcli

import (
"database/sql"
"fmt"
"strings"
)

Expand All @@ -21,8 +20,9 @@ func (xc *XaClientBase) HandleCallback(gid string, branchID string, action strin
}
defer db.Close()
xaID := gid + "-" + branchID
_, err = DBExec(db, fmt.Sprintf("xa %s '%s'", action, xaID))
if err != nil && strings.Contains(err.Error(), "Error 1397: XAER_NOTA") { // 重复commit/rollback同一个id,报这个错误,忽略
_, err = DBExec(db, getXaSQL(action, xaID))
if err != nil &&
(strings.Contains(err.Error(), "Error 1397: XAER_NOTA") || strings.Contains(err.Error(), "does not exist")) { // 重复commit/rollback同一个id,报这个错误,忽略
err = nil
}
return err
Expand All @@ -39,9 +39,9 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf
defer func() { db.Close() }()
defer func() {
x := recover()
_, err := DBExec(db, fmt.Sprintf("XA end '%s'", xaBranch))
_, err := DBExec(db, getXaSQL("end", xaBranch))
if x == nil && rerr == nil && err == nil {
_, err = DBExec(db, fmt.Sprintf("XA prepare '%s'", xaBranch))
_, err = DBExec(db, getXaSQL("prepare", xaBranch))
}
if rerr == nil {
rerr = err
Expand All @@ -50,7 +50,7 @@ func (xc *XaClientBase) HandleLocalTrans(xa *TransBase, cb func(*sql.DB) (interf
panic(x)
}
}()
_, rerr = DBExec(db, fmt.Sprintf("XA start '%s'", xaBranch))
_, rerr = DBExec(db, getXaSQL("start", xaBranch))
if rerr != nil {
return
}
Expand Down
11 changes: 9 additions & 2 deletions dtmsvr/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ func lockOneTrans(expireIn time.Duration) *TransGlobal {
trans := TransGlobal{}
owner := GenGid()
db := dbGet()
getTime := func(second int) string {
return fmt.Sprintf(map[string]string{
"mysql": "date_add(now(), interval %d second)",
"postgres": "current_timestamp + interval '%d second'",
}[dtmcli.DBDriver], second)
}
expire := int(expireIn / time.Second)
whereTime := fmt.Sprintf("next_cron_time < %s and next_cron_time > %s and update_time < %s", getTime(expire), getTime(-3600), getTime(expire-3))
// 这里next_cron_time需要限定范围,否则数据量累计之后,会导致查询变慢
// 限定update_time < now - 3,否则会出现刚被这个应用取出,又被另一个取出
dbr := db.Must().Model(&trans).
Where("next_cron_time < date_add(now(), interval ? second) and next_cron_time > date_add(now(), interval -3600 second) and update_time < date_add(now(), interval ? second) and status in ('prepared', 'aborting', 'submitted')", int(expireIn/time.Second), -3+int(expireIn/time.Second)).
Limit(1).Update("owner", owner)
Where(whereTime+"and status in ('prepared', 'aborting', 'submitted')").Limit(1).Update("owner", owner)
if dbr.RowsAffected == 0 {
return nil
}
Expand Down
3 changes: 2 additions & 1 deletion dtmsvr/dtmsvr.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ func updateBranchAsync() {
}
for len(updates) > 0 {
dbr := dbGet().Clauses(clause.OnConflict{
DoUpdates: clause.AssignmentColumns([]string{"status", "finish_time"}),
OnConstraint: "trans_branch_pkey",
DoUpdates: clause.AssignmentColumns([]string{"status", "finish_time"}),
}).Create(updates)
dtmcli.Logf("flushed %d branch status to db. affected: %d", len(updates), dbr.RowsAffected)
if dbr.Error != nil {
Expand Down
72 changes: 72 additions & 0 deletions dtmsvr/dtmsvr.postgres.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
CREATE SCHEMA if not EXISTS dtm /* SQLINES DEMO *** RACTER SET utf8mb4 */;

drop table IF EXISTS dtm.trans_global;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_global_seq;

CREATE TABLE if not EXISTS dtm.trans_global (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_global_seq'),
gid varchar(128) NOT NULL ,
trans_type varchar(45) not null ,
status varchar(45) NOT NULL ,
query_prepared varchar(128) NOT NULL ,
protocol varchar(45) not null,
create_time timestamp(0) DEFAULT NULL,
update_time timestamp(0) DEFAULT NULL,
commit_time timestamp(0) DEFAULT NULL,
finish_time timestamp(0) DEFAULT NULL,
rollback_time timestamp(0) DEFAULT NULL,
next_cron_interval int default null ,
next_cron_time timestamp(0) default null ,
owner varchar(128) not null default '' ,
PRIMARY KEY (id),
CONSTRAINT gid UNIQUE (gid)
) ;

create index if not EXISTS owner on dtm.trans_global(owner);
CREATE INDEX if not EXISTS create_time ON dtm.trans_global (create_time);
CREATE INDEX if not EXISTS update_time ON dtm.trans_global (update_time);
create index if not EXISTS next_cron_time on dtm.trans_global (next_cron_time);

drop table IF EXISTS dtm.trans_branch;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_branch_seq;

CREATE TABLE IF NOT EXISTS dtm.trans_branch (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_branch_seq'),
gid varchar(128) NOT NULL ,
url varchar(128) NOT NULL ,
data TEXT ,
branch_id VARCHAR(128) NOT NULL ,
branch_type varchar(45) NOT NULL ,
status varchar(45) NOT NULL ,
finish_time timestamp(0) DEFAULT NULL,
rollback_time timestamp(0) DEFAULT NULL,
create_time timestamp(0) DEFAULT NULL,
update_time timestamp(0) DEFAULT NULL,
PRIMARY KEY (id),
CONSTRAINT gid_uniq UNIQUE (gid,branch_id, branch_type)
) ;

CREATE INDEX if not EXISTS create_time ON dtm.trans_branch (create_time);
CREATE INDEX if not EXISTS update_time ON dtm.trans_branch (update_time);

drop table IF EXISTS dtm.trans_log;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
CREATE SEQUENCE if not EXISTS dtm.trans_log_seq;

CREATE TABLE IF NOT EXISTS dtm.trans_log (
id int NOT NULL DEFAULT NEXTVAL ('dtm.trans_log_seq'),
gid varchar(128) NOT NULL ,
branch_id varchar(128) DEFAULT NULL ,
action varchar(45) DEFAULT NULL ,
old_status varchar(45) NOT NULL DEFAULT '' ,
new_status varchar(45) NOT NULL ,
detail TEXT NOT NULL ,
create_time timestamp(0) DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
) ;

CREATE INDEX if not EXISTS gid ON dtm.trans_log (gid);
CREATE INDEX if not EXISTS create_time ON dtm.trans_log (create_time);

11 changes: 8 additions & 3 deletions examples/base_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/yedf/dtm/common"
"github.com/yedf/dtm/dtmcli"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

Expand Down Expand Up @@ -129,9 +130,13 @@ func BaseAddRoute(app *gin.Engine) {
if reqFrom(c).TransOutResult == dtmcli.ResultFailure {
return dtmcli.MapFailure, nil
}
gdb, err := gorm.Open(mysql.New(mysql.Config{
Conn: db,
}), &gorm.Config{})
var dia gorm.Dialector = nil
if dtmcli.DBDriver == dtmcli.DriverMysql {
dia = mysql.New(mysql.Config{Conn: db})
} else if dtmcli.DBDriver == dtmcli.DriverPostgres {
dia = postgres.New(postgres.Config{Conn: db})
}
gdb, err := gorm.Open(dia, &gorm.Config{})
if err != nil {
return nil, err
}
Expand Down
62 changes: 62 additions & 0 deletions examples/examples.postgres.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
CREATE SCHEMA if not exists dtm_busi /* SQLINES DEMO *** RACTER SET utf8mb4 */;
create SCHEMA if not exists dtm_barrier /* SQLINES DEMO *** RACTER SET utf8mb4 */;

drop table if exists dtm_busi.user_account;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create sequence if not exists dtm_busi.user_account_seq;

create table if not exists dtm_busi.user_account(
id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_busi.user_account_seq'),
user_id int UNIQUE ,
balance DECIMAL(10, 2) not null default '0',
create_time timestamp(0) DEFAULT now(),
update_time timestamp(0) DEFAULT now()
);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists create_idx on dtm_busi.user_account(create_time);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists update_idx on dtm_busi.user_account(update_time);

TRUNCATE dtm_busi.user_account;
insert into dtm_busi.user_account (user_id, balance) values (1, 10000), (2, 10000);

drop table if exists dtm_busi.user_account_trading;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create sequence if not exists dtm_busi.user_account_trading_seq;

create table if not exists dtm_busi.user_account_trading( -- SQLINES DEMO *** �冻结的金额
id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_busi.user_account_trading_seq'),
user_id int UNIQUE ,
trading_balance DECIMAL(10, 2) not null default '0',
create_time timestamp(0) DEFAULT now(),
update_time timestamp(0) DEFAULT now()
);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists create_idx on dtm_busi.user_account_trading(create_time);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists update_idx on dtm_busi.user_account_trading(update_time);

TRUNCATE dtm_busi.user_account_trading;
insert into dtm_busi.user_account_trading (user_id, trading_balance) values (1, 0), (2, 0);


drop table if exists dtm_barrier.barrier;
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create sequence if not exists dtm_barrier.barrier_seq;

create table if not exists dtm_barrier.barrier(
id int PRIMARY KEY DEFAULT NEXTVAL ('dtm_barrier.barrier_seq'),
trans_type varchar(45) default '' ,
gid varchar(128) default'',
branch_id varchar(128) default '',
branch_type varchar(45) default '',
reason varchar(45) default '' ,
result varchar(2047) default null ,
create_time timestamp(0) DEFAULT now(),
update_time timestamp(0) DEFAULT now(),
UNIQUE (gid, branch_id, branch_type)
);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists create_idx on dtm_barrier.barrier(create_time);
-- SQLINES LICENSE FOR EVALUATION USE ONLY
create index if not exists update_idx on dtm_barrier.barrier(update_time);
Loading

0 comments on commit 9518a93

Please sign in to comment.