From d7d751572d5bf235164ac3b5c33ccae77fff7f3c Mon Sep 17 00:00:00 2001 From: Felix Sun Date: Sat, 20 Apr 2019 13:02:19 +0800 Subject: [PATCH] Improve Get --- README.md | 5 +- example_test.go | 9 ++-- get.go | 40 ++++++--------- map_slice_test.go | 4 +- populate_test.go | 14 ------ set.go | 5 +- struct_test.go | 121 ++++++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 148 insertions(+), 50 deletions(-) delete mode 100644 populate_test.go diff --git a/README.md b/README.md index 8b5188a..406a85e 100644 --- a/README.md +++ b/README.md @@ -76,8 +76,9 @@ By given these structs } type Company struct { - Name string - Phone *Phone + Name string + Phone *Phone + Phone2 **Phone } type Department struct { diff --git a/example_test.go b/example_test.go index cb4b1c4..837ddc1 100644 --- a/example_test.go +++ b/example_test.go @@ -1,8 +1,10 @@ -package reflectutils +package reflectutils_test import ( "encoding/json" "fmt" + + . "github.com/sunfmin/reflectutils" ) // By given these structs @@ -24,8 +26,9 @@ func ExampleSet_0init() { } type Company struct { - Name string - Phone *Phone + Name string + Phone *Phone + Phone2 **Phone } type Department struct { diff --git a/get.go b/get.go index 5abb80a..fb6c062 100644 --- a/get.go +++ b/get.go @@ -1,7 +1,6 @@ package reflectutils import ( - "errors" "fmt" "reflect" "strings" @@ -19,36 +18,22 @@ func MustGet(i interface{}, name string) (value interface{}) { // Get value of a struct by path using reflect. func Get(i interface{}, name string) (value interface{}, err error) { - - v := reflect.ValueOf(i) - - if v.Kind() != reflect.Ptr { - err = errors.New("get object must be a pointer") - return - } - - for v.Elem().Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.IsNil() { - return - } - - sv := v.Elem() + // printv(i, name) if name == "" { - value = sv.Interface() + value = i return } + v := reflect.ValueOf(i) + var token *dotToken token, err = nextDot(name) if err != nil { return } - // printv(sv.Interface(), name, value) + sv := v if sv.Kind() == reflect.Map { // map must have string type @@ -72,7 +57,7 @@ func Get(i interface{}, name string) (value interface{}, err error) { mapElem.Set(existElem) } - value, err = Get(mapElem.Addr().Interface(), token.Left) + value, err = Get(mapElem.Interface(), token.Left) if err != nil { return } @@ -96,7 +81,7 @@ func Get(i interface{}, name string) (value interface{}, err error) { return } - value, err = Get(arrayElem.Addr().Interface(), token.Left) + value, err = Get(arrayElem.Interface(), token.Left) if err != nil { return } @@ -104,6 +89,14 @@ func Get(i interface{}, name string) (value interface{}, err error) { return } + if sv.Kind() != reflect.Struct { + for sv.Elem().Kind() == reflect.Ptr { + sv = sv.Elem() + } + + sv = sv.Elem() + } + if sv.Kind() == reflect.Struct { fv := sv.FieldByNameFunc(func(fname string) bool { return strings.EqualFold(fname, token.Field) @@ -113,8 +106,7 @@ func Get(i interface{}, name string) (value interface{}, err error) { err = NoSuchFieldError return } - - value, err = Get(fv.Addr().Interface(), token.Left) + value, err = Get(fv.Interface(), token.Left) return } diff --git a/map_slice_test.go b/map_slice_test.go index 0b4b512..5628368 100644 --- a/map_slice_test.go +++ b/map_slice_test.go @@ -1,7 +1,9 @@ -package reflectutils +package reflectutils_test import ( "testing" + + . "github.com/sunfmin/reflectutils" ) type mapTest struct { diff --git a/populate_test.go b/populate_test.go deleted file mode 100644 index 5647eb8..0000000 --- a/populate_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package reflectutils - -// import ( -// "testing" -// ) - -// func TestPopulate(t *testing.T) { -// var v1 *Person -// Populate(&v1) -// if v1.Company == nil { -// t.Errorf("not populated %+v", v1) -// } - -// } diff --git a/set.go b/set.go index f9bf188..f44d133 100644 --- a/set.go +++ b/set.go @@ -224,17 +224,16 @@ func nextDot(name string) (t *dotToken, err error) { return } -func printv(v interface{}, name interface{}, value string) { +func printv(v interface{}, name interface{}) { log.Println("=====") rv := reflect.ValueOf(v) log.Printf( - "\n\tname: %+v, \n\tv: %+v, \n\trv: %+v, \n\trv.Kind(): %+v, \n\trv.Type(): %+v, \n\trv.IsNil(): %+v, \n\trv.IsValid(): %+v", + "\n\tname: %+v, \n\tv: %+v, \n\trv: %+v, \n\trv.Kind(): %+v, \n\trv.Type(): %+v, \n\trv.IsValid(): %+v", name, v, rv, rv.Kind(), rv.Type(), - "", rv.IsValid(), ) log.Println("=====") diff --git a/struct_test.go b/struct_test.go index f1deda4..55bd74f 100644 --- a/struct_test.go +++ b/struct_test.go @@ -1,8 +1,11 @@ -package reflectutils +package reflectutils_test import ( "fmt" + "reflect" "testing" + + . "github.com/sunfmin/reflectutils" ) type Person struct { @@ -22,8 +25,9 @@ type Language struct { } type Company struct { - Name string - Phone *Phone + Name string + Phone *Phone + Phone2 **Phone `json:"-"` } type Department struct { @@ -198,3 +202,114 @@ func TestSetOtherPointers(t *testing.T) { t.Errorf("set failed %+v", v) } } + +var phone158 = &Phone{ + Number: "158", +} + +var getcases = []struct { + Name string + Value interface{} + ExpectedGetType string +}{ + { + Name: "", + Value: &Person{ + Company: &Company{ + Name: "The Plant", + }, + }, + ExpectedGetType: "*reflectutils_test.Person", + }, + { + Name: "Company", + Value: &Person{ + Company: &Company{ + Name: "The Plant", + }, + }, + ExpectedGetType: "*reflectutils_test.Company", + }, + { + Name: "Company.Phone2", + Value: &Person{ + Company: &Company{ + Phone2: &phone158, + }, + }, + ExpectedGetType: "**reflectutils_test.Phone", + }, + { + Name: "Company.Phone2.Number", + Value: &Person{ + Company: &Company{ + Phone2: &phone158, + }, + }, + ExpectedGetType: "string", + }, + { + Name: "Phones.Home", + Value: &Person{ + Phones: map[string]string{ + "Home": "158", + }, + }, + ExpectedGetType: "string", + }, + { + Name: "Languages.en_US.Name", + Value: &Person{ + Languages: map[string]Language{ + "en_US": Language{ + Name: "English", + }, + }, + }, + ExpectedGetType: "string", + }, + { + Name: "Projects[1].Name", + Value: &Person{ + Projects: []*Project{ + { + Name: "Top1", + }, + { + Name: "Top2", + }, + }, + }, + ExpectedGetType: "string", + }, +} + +func TestGet(t *testing.T) { + + for _, c := range getcases { + t.Run(c.Name, func(t2 *testing.T) { + v, err := Get(c.Value, c.Name) + if err != nil { + panic(err) + } + + typeName := reflect.ValueOf(v).Type().String() + if typeName != c.ExpectedGetType { + panic(fmt.Sprintf("expected is %v, but was %v", c.ExpectedGetType, typeName)) + } + }) + } + + var p = &Person{ + Company: &Company{ + Name: "The Plant 1", + }, + } + + c1 := MustGet(p, "Company").(*Company) + + c1.Name = "The Plant 2" + if p.Company.Name != c1.Name { + panic(fmt.Sprintf("expected is %v, but was %v", c1.Name, p.Company.Name)) + } +}