diff --git a/table.go b/table.go index 3b3dbf4..0c9932b 100644 --- a/table.go +++ b/table.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "reflect" - "strconv" "strings" "time" @@ -61,7 +60,7 @@ func (d *defaultParser) ParseSQL(sql string) error { d.tables[tableName] = &Table{ Name: tableName, - Comment: create.Table.TableInfo.Comment, + Comment: getTableComment(create), ColumnTypes: d.getColumnTypes(create), Indexes: d.getIndexes(create), } @@ -142,6 +141,21 @@ func (d *defaultParser) ParseSQL(sql string) error { return nil } +func getTableComment(create *ast.CreateTableStmt) string { + if create == nil { + return "" + } + if create.Table.TableInfo != nil && create.Table.TableInfo.Comment != "" { + return create.Table.TableInfo.Comment + } + for _, tp := range create.Options { + if tp.Tp == ast.TableOptionComment { + return tp.StrValue + } + } + return "" +} + func (d *defaultParser) getColumnTypes(create *ast.CreateTableStmt) (cols []gorm.ColumnType) { if create == nil || len(create.Cols) == 0 { return nil @@ -216,19 +230,9 @@ func (*defaultParser) getColumnType(col *ast.ColumnDef) gorm.ColumnType { } if opt.Tp == ast.ColumnOptionDefaultValue { if v, ok := opt.Expr.(*test_driver.ValueExpr); ok { - dv := sql.NullString{ - Valid: true, + ct.DefaultValueValue = sql.NullString{ + Valid: true, String: fmt.Sprint(v.Datum.GetValue()), } - switch v.Datum.Kind() { - case test_driver.KindInt64: - dv.String = strconv.FormatInt(v.Datum.GetInt64(), 10) - case test_driver.KindUint64: - dv.String = strconv.FormatUint(v.Datum.GetUint64(), 10) - default: - dv.String = v.Datum.GetString() - } - - ct.DefaultValueValue = dv continue } diff --git a/tests/gen_test.go b/tests/gen_test.go index ed6502f..d0750be 100644 --- a/tests/gen_test.go +++ b/tests/gen_test.go @@ -29,7 +29,7 @@ func TestSqlGen(t *testing.T) { } fmt.Println(db.Migrator().GetTables()) - cts, err := db.Migrator().ColumnTypes("users") + cts, err := db.Migrator().ColumnTypes("credit_cards") if err != nil { fmt.Println(err.Error()) diff --git a/tests/sql/01_tables.sql b/tests/sql/01_tables.sql index 47aacd1..ca023d9 100644 --- a/tests/sql/01_tables.sql +++ b/tests/sql/01_tables.sql @@ -10,13 +10,13 @@ CREATE TABLE `credit_cards` ( `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, `created_at` datetime(3) DEFAULT NULL, `updated_at` datetime(3) DEFAULT NULL, - `deleted_at` datetime(3) DEFAULT NULL, + `deleted_at` datetime(3) DEFAULT '1970-01-01 08:01:00', `number` longtext, - `customer_refer` bigint(20) unsigned DEFAULT NULL, + `customer_refer` bigint(20) unsigned NOT NULL DEFAULT 1, `bank_id` bigint(20) unsigned DEFAULT NULL, PRIMARY KEY (`id`), KEY `idx_credit_cards_deleted_at` (`deleted_at`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='店铺用户表'; CREATE TABLE `customers` ( @@ -27,7 +27,7 @@ CREATE TABLE `customers` ( `bank_id` bigint(20) unsigned DEFAULT NULL, PRIMARY KEY (`id`), KEY `idx_customers_deleted_at` (`deleted_at`) -) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='店铺用户表'; CREATE TABLE `people` (