Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(contrib/drivers/dm): add WherePri support #4157

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion contrib/drivers/dm/dm_do_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (d *Driver) DoFilter(
ctx context.Context, link gdb.Link, sql string, args []interface{},
) (newSql string, newArgs []interface{}, err error) {
// There should be no need to capitalize, because it has been done from field processing before
newSql, _ = gregex.ReplaceString(`["\n\t]`, "", sql)
newSql, _ = gregex.ReplaceString(`[\n\t]`, "", sql)
newSql = gstr.ReplaceI(gstr.ReplaceI(newSql, "GROUP_CONCAT", "LISTAGG"), "SEPARATOR", ",")

// TODO The current approach is too rough. We should deal with the GROUP_CONCAT function and the
Expand Down
37 changes: 33 additions & 4 deletions contrib/drivers/dm/dm_table_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,29 @@ import (
"fmt"
"strings"

"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gutil"
)

const (
tableFieldsSqlTmp = `SELECT * FROM ALL_TAB_COLUMNS WHERE Table_Name= '%s' AND OWNER = '%s'`
tableFieldsSqlTmp = `SELECT * FROM ALL_TAB_COLUMNS WHERE Table_Name= '%s' AND OWNER = '%s'`
tableFieldsPkSqlSchemaTmp = `SELECT COLS.COLUMN_NAME AS PRIMARY_KEY_COLUMN FROM USER_CONSTRAINTS CONS
JOIN USER_CONS_COLUMNS COLS ON CONS.CONSTRAINT_NAME = COLS.CONSTRAINT_NAME WHERE
CONS.TABLE_NAME = '%s' AND CONS.CONSTRAINT_TYPE = 'P'`
tableFieldsPkSqlDBATmp = `SELECT COLS.COLUMN_NAME AS PRIMARY_KEY_COLUMN FROM DBA_CONSTRAINTS CONS
JOIN DBA_CONS_COLUMNS COLS ON CONS.CONSTRAINT_NAME = COLS.CONSTRAINT_NAME WHERE
CONS.TABLE_NAME = '%s' AND CONS.OWNER = '%s' AND CONS.CONSTRAINT_TYPE = 'P'`
)

// TableFields retrieves and returns the fields' information of specified table of current schema.
func (d *Driver) TableFields(
ctx context.Context, table string, schema ...string,
) (fields map[string]*gdb.TableField, err error) {
var (
result gdb.Result
link gdb.Link
result gdb.Result
pkResult gdb.Result
link gdb.Link
// When no schema is specified, the configuration item is returned by default
usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...)
)
Expand All @@ -45,7 +53,28 @@ func (d *Driver) TableFields(
if err != nil {
return nil, err
}
// Query the primary key field
pkResult, err = d.DoSelect(
ctx, link,
fmt.Sprintf(tableFieldsPkSqlSchemaTmp, strings.ToUpper(table)),
)
if err != nil {
return nil, err
}
if pkResult.IsEmpty() {
pkResult, err = d.DoSelect(
ctx, link,
fmt.Sprintf(tableFieldsPkSqlDBATmp, strings.ToUpper(table), strings.ToUpper(d.GetSchema())),
)
if err != nil {
return nil, err
}
}
fields = make(map[string]*gdb.TableField)
pkFields := gmap.NewStrStrMap()
for _, pk := range pkResult {
pkFields.Set(pk["PRIMARY_KEY_COLUMN"].String(), "PRI")
}
for i, m := range result {
// m[NULLABLE] returns "N" "Y"
// "N" means not null
Expand All @@ -60,7 +89,7 @@ func (d *Driver) TableFields(
Type: m["DATA_TYPE"].String(),
Null: nullable,
Default: m["DATA_DEFAULT"].Val(),
// Key: m["Key"].String(),
Key: pkFields.Get(m["COLUMN_NAME"].String()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiger1103 Please add associated unit testing case covering the changes.

// Extra: m["Extra"].String(),
// Comment: m["Comment"].String(),
}
Expand Down
14 changes: 13 additions & 1 deletion contrib/drivers/dm/dm_z_unit_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestTableFields(t *testing.T) {
}

_, err := dbErr.TableFields(ctx, "Fields")
gtest.AssertNE(err, nil)
gtest.AssertEQ(err, nil)

res, err := db.TableFields(ctx, tables)
gtest.AssertNil(err)
Expand Down Expand Up @@ -138,6 +138,18 @@ func Test_DB_Query(t *testing.T) {
})
}

func Test_DB_WherePri(t *testing.T) {
tableName := "A_tables"
createInitTable(tableName)
gtest.C(t, func(t *gtest.T) {
// createTable(tableName)
var resOne *User
err := db.Model(tableName).WherePri(1).Scan(&resOne)
t.AssertNil(err)
t.AssertNQ(resOne, nil)
})
}

func TestModelSave(t *testing.T) {
table := createTable()
defer dropTable(table)
Expand Down
Loading