diff --git a/main.go b/main.go index b668c2e..833e1f2 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ type column struct { Name string Type string IsNullable string - Default string + Default interface{} After string } @@ -81,7 +81,7 @@ func main() { log.Fatalln("open file error !") } // 创建一个日志对象 - dLog = log.New(logFile, "[Info]", log.LstdFlags|log.Lshortfile) + dLog = log.New(logFile, "[Info]", log.LstdFlags) //|log.Lshortfile) //配置一个日志格式的前缀 dLog.SetPrefix("[Info]") //配置log的Flag参数 @@ -189,7 +189,7 @@ func TriggerDiff(db1, db2 *sql.DB, schema1, schema2 string) bool { dLog.Printf("两个数据库不同的触发器,共有%d个,分别是:%s", len(dt), dt) return false } - dLog.Printf("两个数据库触发器相同") + // dLog.Printf("两个数据库触发器相同") return true } @@ -234,7 +234,7 @@ func FunctionDiff(db1, db2 *sql.DB, schema1, schema2 string) bool { dLog.Printf("两个数据库不同的函数,共有%d个,分别是:%s", len(dt), dt) return false } - dLog.Printf("两个数据库函数相同") + // dLog.Printf("两个数据库函数相同") return true } @@ -262,6 +262,29 @@ func getFunctionName(s *sql.DB, schema string) (ts []string, err error) { return } +func genAlterSql(t string, col column) string { + var after string + if col.After != "" { + after = fmt.Sprintf(" AFTER `%s`", col.After) + } + + var isNull string + if col.IsNullable == "YES" { + isNull = " NULL" + } else { + isNull = " NOT NULL" + } + + var defaultValue string + if col.Default == nil { + defaultValue = " DEFAULT NULL" + } else if col.Default != "" { + defaultValue = fmt.Sprintf(" DEFAULT '%s'", col.Default) + } + + return fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN `%s` %s%s%s%s;", t, col.Name, col.Type, isNull, defaultValue, after) +} + // ColumnDiff 对比函数的不同 func ColumnDiff(db1, db2 *sql.DB, schema1, schema2 string, table []string) { for _, t := range table { @@ -278,23 +301,14 @@ func ColumnDiff(db1, db2 *sql.DB, schema1, schema2 string, table []string) { // dLog.Printf("两个数据库%s表,有不同的列,共有%d个,分别是:%s", t, len(dt), dt) col1, col2 := columnDiff(columnName1, columnName2) - dLog.Printf("%s数据库%s表,列不相同", schema1, t) + dLog.Printf("%s数据库%s表,列不相同:%d", schema1, t, len(col1)) for _, col := range col1 { - // - var after string - if col.After != "" { - after = fmt.Sprintf("AFTER %s", col.After) - } - dLog.Printf("ALTER TABLE %s ADD COLUMN %s %s %s", t, col.Name, col.Type, after) + dLog.Printf(genAlterSql(t, col)) } - dLog.Printf("%s数据库%s表,列不相同", schema2, t) + dLog.Printf("%s数据库%s表,列不相同:%d", schema2, t, len(col2)) for _, col := range col2 { - var after string - if col.After != "" { - after = fmt.Sprintf("AFTER %s", col.After) - } - dLog.Printf("ALTER TABLE %s ADD COLUMN %s %s %s", t, col.Name, col.Type, after) + dLog.Printf(genAlterSql(t, col)) } } else { //dLog.Printf("两个数据库%s表,列相同", t) @@ -303,7 +317,7 @@ func ColumnDiff(db1, db2 *sql.DB, schema1, schema2 string, table []string) { } func getColumnName(s *sql.DB, schema, table string) (ts []column, err error) { - stm, perr := s.Prepare("select COLUMN_NAME,column_type,column_type,is_nullable from information_schema.columns where TABLE_SCHEMA=? and TABLE_NAME=? order by ordinal_position asc") + stm, perr := s.Prepare("select COLUMN_NAME,column_type,column_default,is_nullable from information_schema.columns where TABLE_SCHEMA=? and TABLE_NAME=? order by ordinal_position asc") if perr != nil { err = perr return @@ -322,7 +336,7 @@ func getColumnName(s *sql.DB, schema, table string) (ts []column, err error) { for q.Next() { var column_name string var column_type string - var column_default string + var column_default interface{} var is_nullable string if err := q.Scan(&column_name, &column_type, &column_default, &is_nullable); err != nil { @@ -362,7 +376,7 @@ func IndexDiff(db1, db2 *sql.DB, schema1, schema2 string, table []string) { dt := diffName(indexName1, indexName2) dLog.Printf("两个数据库%s表,有不同的索引,共有%d个,分别是:%s", t, len(dt), dt) } else { - dLog.Printf("两个数据库%s表,索引相同", t) + // dLog.Printf("两个数据库%s表,索引相同", t) } } }