Skip to content

Commit

Permalink
Merge pull request #2458 from actiontech/issue-2457
Browse files Browse the repository at this point in the history
MYSQL规则 避免对条件字段使用函数操作 和 不建议在WHERE条件中使用与过滤字段不一致的数据类型 审核sql panic报错
  • Loading branch information
ColdWaterLW committed Jun 19, 2024
2 parents bc947e9 + 8d6963e commit 3f884f0
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 31 deletions.
27 changes: 27 additions & 0 deletions sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2795,6 +2795,18 @@ select v1 from exist_db.exist_tb_1 where v2 = "3"
`,
newTestResult(),
)

runSingleRuleInspectCase(rule, t, "select: without from condition", DefaultMysqlInspect(), `select 1`, newTestResult())

runSingleRuleInspectCase(rule, t, "select: without where condition", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1`, newTestResult())

runSingleRuleInspectCase(rule, t, "select: next select with function", DefaultMysqlInspect(), `select * from (select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3") as t1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))

runSingleRuleInspectCase(rule, t, "select union select 1", DefaultMysqlInspect(), `select 1 union select 1`, newTestResult())

runSingleRuleInspectCase(rule, t, "select: union select", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3" union select * from exist_db.exist_tb_1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))

runSingleRuleInspectCase(rule, t, "union next select", DefaultMysqlInspect(), `select * from exist_db.exist_tb_1 union all select * from (select * from exist_db.exist_tb_1 where nvl(v2,"0") = "3") as t1`, newTestResult().addResult(rulepkg.DMLCheckWhereExistFunc))
}

func Test_DDLCheckCreateTimeColumn(t *testing.T) {
Expand Down Expand Up @@ -3027,6 +3039,21 @@ select v1 from exist_db.exist_tb_1 where id = 3;
`,
newTestResult(),
)

runSingleRuleInspectCase(rule, t, "select: not exist from condition", DefaultMysqlInspect(), `select 1;`, newTestResult())

runSingleRuleInspectCase(rule, t, "select: not exist where condition", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1;`, newTestResult())

runSingleRuleInspectCase(rule, t, "select: nest select", DefaultMysqlInspect(), `select s.* from (select v1 from exist_db.exist_tb_1 where id = "3") s`,
newTestResult().addResult(rulepkg.DMLCheckWhereExistImplicitConversion))

runSingleRuleInspectCase(rule, t, "select: nest select", DefaultMysqlInspect(), `select s.* from (select v1 from exist_db.exist_tb_1 where id = 3) s`, newTestResult())

runSingleRuleInspectCase(rule, t, "UNION: union all select", DefaultMysqlInspect(), `select 1 union all select 1`, newTestResult())

runSingleRuleInspectCase(rule, t, "UNION: union nest select", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1 union select s.v1 from (select v1 from exist_db.exist_tb_1 where v1 = "3") s`, newTestResult())

runSingleRuleInspectCase(rule, t, "UNION: union nest select", DefaultMysqlInspect(), `select v1 from exist_db.exist_tb_1 union select s.v1 from (select v1 from exist_db.exist_tb_1 where v1 = 3) s`, newTestResult().addResult(rulepkg.DMLCheckWhereExistImplicitConversion))
}

func TestCheckMultiSelectWhereExistImplicitConversion(t *testing.T) {
Expand Down
82 changes: 51 additions & 31 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -4811,21 +4811,39 @@ func checkDMLWithBatchInsertMaxLimits(input *RuleHandlerInput) error {

func checkWhereExistFunc(input *RuleHandlerInput) error {
tables := []*ast.TableName{}
switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if stmt.Where != nil {
tableSources := util.GetTableSources(stmt.From.TableRefs)
hasExistFunc := func(stmt *ast.SelectStmt) bool {
selectExtractor := util.SelectStmtExtractor{}
stmt.Accept(&selectExtractor)
for _, selectStmt := range selectExtractor.SelectStmts {
if selectStmt.Where == nil || selectStmt.From == nil {
continue
}

tableSources := util.GetTableSources(selectStmt.From.TableRefs)
// not select from table statement
if len(tableSources) < 1 {
break
continue
}

for _, tableSource := range tableSources {
switch source := tableSource.Source.(type) {
case *ast.TableName:
tables = append(tables, source)
}
}
checkExistFunc(input.Ctx, input.Rule, input.Res, tables, stmt.Where)

if checkExistFunc(input.Ctx, input.Rule, input.Res, tables, selectStmt.Where) {
return true
}
}

return false
}

switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if hasExistFunc(stmt) {
break
}
case *ast.UpdateStmt:
if stmt.Where != nil {
Expand All @@ -4843,18 +4861,8 @@ func checkWhereExistFunc(input *RuleHandlerInput) error {
checkExistFunc(input.Ctx, input.Rule, input.Res, util.GetTables(stmt.TableRefs.TableRefs), stmt.Where)
}
case *ast.UnionStmt:
for _, ss := range stmt.SelectList.Selects {
tableSources := util.GetTableSources(ss.From.TableRefs)
if len(tableSources) < 1 {
continue
}
for _, tableSource := range tableSources {
switch source := tableSource.Source.(type) {
case *ast.TableName:
tables = append(tables, source)
}
}
if checkExistFunc(input.Ctx, input.Rule, input.Res, tables, ss.Where) {
for _, selectStmt := range stmt.SelectList.Selects {
if hasExistFunc(selectStmt) {
break
}
}
Expand Down Expand Up @@ -4887,15 +4895,31 @@ func checkExistFunc(ctx *session.Context, rule driverV2.Rule, res *driverV2.Audi
}

func checkWhereColumnImplicitConversion(input *RuleHandlerInput) error {
switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if stmt.Where != nil {
tableSources := util.GetTableSources(stmt.From.TableRefs)
// not select from table statement
hasWhereColumnImplicitConversionFunc := func(stmt *ast.SelectStmt) bool {
selectExtractor := util.SelectStmtExtractor{}
stmt.Accept(&selectExtractor)
for _, selectStmt := range selectExtractor.SelectStmts {
if selectStmt.From == nil || selectStmt.Where == nil {
continue
}

tableSources := util.GetTableSources(selectStmt.From.TableRefs)
if len(tableSources) < 1 {
break
continue
}
checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, stmt.Where)

if checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, selectStmt.Where) {
return true
}
}

return false
}

switch stmt := input.Node.(type) {
case *ast.SelectStmt:
if hasWhereColumnImplicitConversionFunc(stmt) {
break
}
case *ast.UpdateStmt:
if stmt.Where != nil {
Expand All @@ -4908,12 +4932,8 @@ func checkWhereColumnImplicitConversion(input *RuleHandlerInput) error {
checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, stmt.Where)
}
case *ast.UnionStmt:
for _, ss := range stmt.SelectList.Selects {
tableSources := util.GetTableSources(ss.From.TableRefs)
if len(tableSources) < 1 {
continue
}
if checkWhereColumnImplicitConversionFunc(input.Ctx, input.Rule, input.Res, tableSources, ss.Where) {
for _, selectStmt := range stmt.SelectList.Selects {
if hasWhereColumnImplicitConversionFunc(selectStmt) {
break
}
}
Expand Down

0 comments on commit 3f884f0

Please sign in to comment.