From 2d8f64ef5d88139d204303d4761af45047b9cd4c Mon Sep 17 00:00:00 2001 From: Takuya Ueda Date: Thu, 19 Nov 2020 23:45:41 +0900 Subject: [PATCH] Add Field --- types.go | 24 ++++++++++++++++++++++ types_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/types.go b/types.go index 46b9706..6039a58 100644 --- a/types.go +++ b/types.go @@ -131,6 +131,30 @@ func HasField(s *types.Struct, f *types.Var) bool { return false } +// Field returns field of the struct type. +// If the type is not struct or has not the field, +// Field returns -1, nil. +// If the type is a named type or a pointer type, +// Field calls itself recursively with +// an underlying type or an element type of pointer. +func Field(t types.Type, name string) (int, *types.Var) { + switch t := t.(type) { + case *types.Pointer: + return Field(t.Elem(), name) + case *types.Named: + return Field(t.Underlying(), name) + case *types.Struct: + for i := 0; i < t.NumFields(); i++ { + f := t.Field(i) + if f.Name() == name { + return i, f + } + } + } + + return -1, nil +} + func TypesInfo(info ...*types.Info) *types.Info { if len(info) == 0 { return nil diff --git a/types_test.go b/types_test.go index a539c33..ac9a812 100644 --- a/types_test.go +++ b/types_test.go @@ -115,3 +115,58 @@ func TestUnder(t *testing.T) { }) } } + +func TestField(t *testing.T) { + t.Parallel() + + lookup := func(pass *analysis.Pass, n string) (types.Type, error) { + _, obj := pass.Pkg.Scope().LookupParent(n, token.NoPos) + if obj == nil { + return nil, fmt.Errorf("does not find: %s", n) + } + return obj.Type(), nil + } + + cases := map[string]struct { + src string + typ string + field string + want int + }{ + "nomarl": {"type a struct{n int}", "a", "n", 0}, + "nofield": {"type a struct{n int}", "a", "m", -1}, + "empty": {"type a struct{}", "a", "n", -1}, + "two": {"type a struct{n, m int}", "a", "m", 1}, + "nonamed": {"var a struct{n, m int}", "a", "m", 1}, + "ptr": {"var a *struct{n, m int}", "a", "m", 1}, + "namednamed": {"type a struct{n int}; type b a", "b", "n", 0}, + "alias": {"type a struct{n int}; type b = a", "b", "n", 0}, + } + + for name, tt := range cases { + name, tt := name, tt + t.Run(name, func(t *testing.T) { + t.Parallel() + a := &analysis.Analyzer{ + Name: name + "Analyzer", + Run: func(pass *analysis.Pass) (interface{}, error) { + typ, err := lookup(pass, tt.typ) + if err != nil { + return nil, err + } + + got, _ := analysisutil.Field(typ, tt.field) + if tt.want != got { + return nil, fmt.Errorf("want %v but got %v", tt.want, got) + } + return nil, nil + }, + } + path := filepath.Join(name, name+".go") + dir := WriteFiles(t, map[string]string{ + path: fmt.Sprintf("package %s\n%s", name, tt.src), + }) + analysistest.Run(t, dir, a, name) + }) + } +}