diff --git a/example/dal/model/user.go b/example/dal/model/user.go index 007aca4..b9c1ed0 100644 --- a/example/dal/model/user.go +++ b/example/dal/model/user.go @@ -4,7 +4,9 @@ // @Description: package model -import "time" +import ( + "time" +) // User db model struct, will be mapped to row of db by ORM. type User struct { @@ -34,9 +36,9 @@ type UserWhere struct { type UserUpdate struct { ID *int64 `sql_field:"id"` Name *string `sql_field:"name"` - Balance *int64 `gorm:"column:balance"` - BalanceAdd *int64 `gorm:"column:balance" sql_expr:"+"` - BalanceMinus *int64 `gorm:"column:balance" sql_expr:"-"` + Balance *int64 `sql_field:"balance"` + BalanceAdd *int64 `sql_field:"balance" sql_expr:"+"` + BalanceMinus *int64 `sql_field:"balance" sql_expr:"-"` UpdateTime *time.Time `sql_field:"update_time"` Deleted *bool `sql_field:"deleted"` } diff --git a/example/main.go b/example/main.go index c4e477e..dfeea3d 100644 --- a/example/main.go +++ b/example/main.go @@ -7,28 +7,35 @@ import ( "github.com/dirac-lee/gdal/example/dal" "github.com/dirac-lee/gdal/example/dal/model" "github.com/dirac-lee/gdal/gutil/gptr" + "github.com/luci/go-render/render" "gorm.io/driver/mysql" "gorm.io/gorm" + "log" "os" "time" ) +var ( + DB *gorm.DB +) + const ( - DemoDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" + MysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" ) func main() { - debug := os.Getenv("DEBUG") - fmt.Println(debug) - db, err := gorm.Open(mysql.Open(DemoDSN)) + var err error + DB, err = gorm.Open(mysql.Open(MysqlDSN)) if err != nil { panic(err) } + RunMigrations() + ctx := context.Background() - userDAL := dal.NewUserDAL(db) + userDAL := dal.NewUserDAL(DB) - { //创建单条记录 + { // create single record now := time.Now() po := model.User{ ID: 110, @@ -38,10 +45,11 @@ func main() { UpdateTime: now, Deleted: false, } - userDAL.Create(ctx, &po) + err := userDAL.Create(ctx, &po) + fmt.Println(err) } - { // 创建多条记录 + { // multiple create now := time.Now() pos := []*model.User{ { @@ -61,71 +69,72 @@ func main() { Deleted: false, }, } - userDAL.MCreate(ctx, &pos) - } - - { // 物理删除 - where := &model.UserWhere{ - IDIn: []int64{110, 120}, - } - userDAL.Delete(ctx, where) - } - - { // 通过ID物理删除 - userDAL.DeleteByID(ctx, 130) + numCreated, err := userDAL.MCreate(ctx, &pos) + fmt.Println(numCreated) + fmt.Println(err) } - { // 更新 + { // update by where condition where := &model.UserWhere{ IDIn: []int64{110, 120}, } update := &model.UserUpdate{ BalanceAdd: gptr.Of[int64](10), } - userDAL.MUpdate(ctx, where, update) + numUpdate, err := userDAL.MUpdate(ctx, where, update) + fmt.Println(numUpdate) + fmt.Println(err) } - { // 通过 ID 更新 + { // update by id update := &model.UserUpdate{ BalanceMinus: gptr.Of[int64](20), } - userDAL.UpdateByID(ctx, 130, update) + err := userDAL.UpdateByID(ctx, 130, update) + fmt.Println(err) } - { // 通用查询 + { // general query var pos []*model.User where := &model.UserWhere{ NameLike: gptr.Of("dirac"), } - userDAL.Find(ctx, &pos, where, gdal.WithDebug()) + err := userDAL.Find(ctx, &pos, where, gdal.WithDebug()) + fmt.Println(err) + fmt.Println(render.Render(pos)) } - { // 查询多条记录 + { // multiple query where := &model.UserWhere{ IDIn: []int64{110, 120}, } pos, err := userDAL.MQuery(ctx, where, gdal.WithDebug(), gdal.WithMaster()) - println(pos, err) + fmt.Println(err) + fmt.Println(render.Render(pos)) } - { // 分页查询1 + { // query by paging: method 1 where := &model.UserWhere{ IDIn: []int64{110, 120}, } pos, total, err := userDAL.MQueryByPaging(ctx, where, gptr.Of[int64](5), nil, gptr.Of("create_time desc")) - println(pos, total, err) + fmt.Println(err) + fmt.Println(total) + fmt.Println(render.Render(pos)) } - { // 分页查询2 + { // query by paging: method 2 where := &model.UserWhere{ IDIn: []int64{110, 120}, } pos, total, err := userDAL.MQueryByPagingOpt(ctx, where, gdal.WithLimit(5), gdal.WithOrder("create_time desc"), gdal.WithDebug()) - println(pos, total, err) + fmt.Println(err) + fmt.Println(total) + fmt.Println(render.Render(pos)) } - { - finalErr := db.Transaction(func(tx *gorm.DB) error { + { // transaction + finalErr := DB.Transaction(func(tx *gorm.DB) error { update := &model.UserUpdate{ BalanceMinus: gptr.Of[int64](20), @@ -142,6 +151,45 @@ func main() { return nil // commit }) - println(finalErr) + fmt.Println(finalErr) + } + + { // 物理删除 + where := &model.UserWhere{ + IDIn: []int64{110, 120}, + } + numDeleted, err := userDAL.Delete(ctx, where) + fmt.Println(numDeleted) + fmt.Println(err) + } + + { // 通过ID物理删除 + numDeleted, err := userDAL.DeleteByID(ctx, 130) + fmt.Println(numDeleted) + fmt.Println(err) + } +} + +func RunMigrations() { + var err error + allModels := []interface{}{&model.User{}} + + DB.Migrator().DropTable("user_friends", "user_speaks") + + if err = DB.Migrator().DropTable(allModels...); err != nil { + log.Printf("Failed to drop table, got error %v\n", err) + os.Exit(1) + } + + if err = DB.AutoMigrate(allModels...); err != nil { + log.Printf("Failed to auto migrate, but got error %v\n", err) + os.Exit(1) + } + + for _, m := range allModels { + if !DB.Migrator().HasTable(m) { + log.Printf("Failed to create table for %#v\n", m) + os.Exit(1) + } } } diff --git a/go.mod b/go.mod index 06359e3..4af2c35 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/bytedance/mockey v1.2.4 + github.com/luci/go-render v0.0.0-20160219211803-9a04cc21af0f github.com/smartystreets/goconvey v1.8.0 gorm.io/driver/mysql v1.5.1 gorm.io/driver/postgres v1.5.2 diff --git a/go.sum b/go.sum index 421aa65..9c1653b 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/luci/go-render v0.0.0-20160219211803-9a04cc21af0f h1:WVPqVsbUsrzAebTEgWRAZMdDOfkFx06iyhbIoyMgtkE= +github.com/luci/go-render v0.0.0-20160219211803-9a04cc21af0f/go.mod h1:aS446i8akEg0DAtNKTVYpNpLPMc0SzsZ0RtGhjl0uFM= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= diff --git a/gutil/genv/env.go b/gutil/genv/env.go index 0fc2b3b..94ac92f 100644 --- a/gutil/genv/env.go +++ b/gutil/genv/env.go @@ -6,6 +6,9 @@ package genv import "os" +// InDebugEnv +// +// @Description: where has "DEBUG" environment. func InDebugEnv() bool { return len(os.Getenv("DEBUG")) > 0 } diff --git a/gutil/greflect/greflect.go b/gutil/greflect/greflect.go index 875a36e..263b23d 100644 --- a/gutil/greflect/greflect.go +++ b/gutil/greflect/greflect.go @@ -6,14 +6,38 @@ import ( ) // Implements -// @Description: 类型 T 是否实现了 Interface 接口 -// @return Interface: 如果实现了,返回填入类型 T 零值的该接口;如果没实现,返回 nil -// @return bool: 是否了实现该接口 +// +// @Description: whether type `T` implements `Interface` +// +// @return Interface: return zero value of `T` if implements; otherwise, nil +// +// @return bool: whether implements func Implements[Interface any](v any) (Interface, bool) { i, ok := v.(Interface) return i, ok } +// GetElemValueTypeOfPtr +// +// @Description: get the element struct Value and Type if `rv` is a pointer to struct; otherwise, return `rv`'s. +// +// @param rv: pointer (maybe deep) to struct +// +// @return reflect.Value: element struct Value +// +// @return reflect.Type: element struct Type +// +// @return error: when `rv` is invalid or element is not a struct +// +// @example +// +// u := &User{ ID: 110, Name: "Bob" } +// rv := reflect.ValueOf(&u) +// elemRv, elemRt, err := GetElemValueTypeOfPtr(rv) +// +// then `elemRv` is `User{ ID: 110, Name: "Bob" }` +// `elemRt` is User +// `err` is nil func GetElemValueTypeOfPtr(rv reflect.Value) (reflect.Value, reflect.Type, error) { rv, err := GetElemValueOfPtr(rv) if err != nil { @@ -24,13 +48,20 @@ func GetElemValueTypeOfPtr(rv reflect.Value) (reflect.Value, reflect.Type, error // GetElemValueOfPtr // -// @Description: 获取指针底层 struct 数据类型。 +// @Description: get the element struct Value if `rv` is a pointer to struct; otherwise, return `rv`'s. // -// @param rv: 底层 struct 数据类型 +// @param rv: pointer (maybe deep) to struct // -// @return reflect.Value: +// @return reflect.Value: element struct Value // // @example +// +// u := &User{ ID: 110, Name: "Bob" } +// rv := reflect.ValueOf(&u) +// elemRv, err := GetElemValueOfPtr(rv) +// +// then `elemRv` is `User{ ID: 110, Name: "Bob" }` +// `err` is nil func GetElemValueOfPtr(rv reflect.Value) (reflect.Value, error) { if !rv.IsValid() { return rv, gerror.InvalidReflectValueErr(rv) @@ -49,12 +80,23 @@ func GetElemValueOfPtr(rv reflect.Value) (reflect.Value, error) { // // @Description: 获取指针、切片、数组底层 struct 数据类型。 // -// @param rt: 底层 struct 数据类型 +// @param rt: embedding of pointer (maybe deep), slice or array to struct. // -// @return reflect.Type: +// @return reflect.Type: element struct Type // // @example +// +// u := []User{{ID: 110, Name: "Bob"}, {ID: 120, Name: "Dirac"}} +// rv := reflect.TypeOf(&u) +// elemRt, err := GetElemStructType(rv) +// +// then `elemRt` is `User` +// `err` is nil func GetElemStructType(rt reflect.Type) (reflect.Type, error) { + type User struct { + ID int64 + Name string + } switch rt.Kind() { case reflect.Struct: return rt, nil diff --git a/gutil/gsql/update.go b/gutil/gsql/update.go index 084554f..c25f3f3 100644 --- a/gutil/gsql/update.go +++ b/gutil/gsql/update.go @@ -11,28 +11,28 @@ import ( // BuildSQLUpdate // -// @Description: 将 MUpdate model struct 编译为 sql 语句中 update 的字典 +// @Description: build Update struct into sql update map // -// @param update: MUpdate model +// @param update: Update struct // -// @return m: +// @return m: map of sql update // // @return err: // // @example: // -// // model 表 +// // model of table_abc // type TableAbc struct { // ID int64 `gorm:"column:id"` // Name string `gorm:"column:name"` // Age int `gorm:"column:p_age"` // } // -// func (Campaign) TableName() string { +// func (TableAbc) TableName() string { // return "table_abc" // } // -// // 需要更新的字段 +// // fields need to update. ⚠️ WARNING: must be pointers // type TableAbcUpdate struct { // Name *string `sql_field:"name"` // Age *int `sql_field:"p_age"` @@ -43,7 +43,7 @@ import ( // Name: &name // } // -// // 下面即 sql: update table_abc set name="byte-er" where id = 1 +// // SQL: update table_abc set name="byte-er" where id = 1 // attrs, err := BuildSQLUpdate(attrs) // if err != nil{ // // do something @@ -70,21 +70,23 @@ func BuildSQLUpdate(update any) (map[string]any, error) { return m, nil } -// 遍历 field,将非 nil 的值拼到 map 中 +// fillSQLUpdateFieldMap walk through all the fields in `rv`,insert non-zero fields into map. +// ⚠️ WARNING: empty slice []T{} is treated as zero value. func fillSQLUpdateFieldMap(rv reflect.Value, st *sqlType) (map[string]any, error) { m := make(map[string]any) for _, name := range st.Names { - column := st.ColumnsMap[name] // 前置函数已经检查过,一定存在 + column := st.ColumnsMap[name] // must be found, guaranteed by previous operations data := rv.FieldByName(column.Name) - // 字段的值是 nil 直接忽略 不做处理 + // skip nil if data.Kind() == reflect.Ptr && data.IsNil() { continue } + // only supports one-level pointers if data.Kind() == reflect.Ptr { data = data.Elem() } if column.Expr != "" { - updater := updaterMap[column.Expr] // 前置函数已经检查过,一定存在 + updater := updaterMap[column.Expr] // must be found, guaranteed by previous operations if updaterResult := updater(column.Field, data.Interface()); updaterResult.SQL != "" { m[column.Field] = updaterResult } @@ -96,9 +98,10 @@ func fillSQLUpdateFieldMap(rv reflect.Value, st *sqlType) (map[string]any, error return m, nil } -// SQLUpdater update语句生成器 +// SQLUpdater update SQL generator type SQLUpdater func(field string, data any) clause.Expr +// updaterMap support `+`, `-` and `merge_json` so far var updaterMap = map[string]SQLUpdater{ "+": func(field string, data any) clause.Expr { return gorm.Expr(field+" + ?", data) @@ -122,6 +125,9 @@ var updaterMap = map[string]SQLUpdater{ }, } +// isMergeJSONStruct +// +// @Description: whether `v` can be a struct or a pointer to struct func isMergeJSONStruct(v any) bool { vt := reflect.TypeOf(v) if vt.Kind() == reflect.Ptr { @@ -130,6 +136,9 @@ func isMergeJSONStruct(v any) bool { return vt.Kind() == reflect.Struct } +// mergeJSONStructToJSONMap +// +// @Description: convert struct to map by tag `json` func mergeJSONStructToJSONMap(v any) (map[string]any, error) { vt := reflect.TypeOf(v) vv := reflect.ValueOf(v) diff --git a/gutil/gsql/where.go b/gutil/gsql/where.go index cefaa91..bbd9f76 100644 --- a/gutil/gsql/where.go +++ b/gutil/gsql/where.go @@ -178,6 +178,9 @@ func buildSQLWhereWithAndOption(rv reflect.Value, rt reflect.Type) (query string } // 遍历 field,使用 and 拼接 where 语句 + +// fillSQLUpdateFieldMap walk through all the fields in `rv`, parsed to single where conditions, then join them with `AND`. +// ⚠️ WARNING: empty slice []T{} is treated as zero value. func fillSQLWhereCondition(rv reflect.Value, rt reflect.Type) (query string, args []any, err error) { args = []any{} qq := new(strings.Builder)