diff --git a/generator/generator.go b/generator/generator.go index e02a62f4..9d8b83d3 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -2,6 +2,7 @@ package generator import ( "bytes" + "fmt" "path/filepath" "sort" "strings" @@ -184,7 +185,7 @@ func NewGenerator(options Options) (*Generator, error) { options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`) } - methods, imports, err := findInterface(fs, srcPackageAST, options.InterfaceName) + methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName) if err != nil { return nil, errors.Wrap(err, "failed to parse interface declaration") } @@ -291,7 +292,7 @@ var errInterfaceNotFound = errors.New("interface type declaration not found") // findInterface looks for the interface declaration in the given directory // and returns a list of the interface's methods and a list of imports from the file // where interface type declaration was found -func findInterface(fs *token.FileSet, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) { +func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) { var found bool var types []*ast.TypeSpec var it *ast.InterfaceType @@ -317,7 +318,7 @@ func findInterface(fs *token.FileSet, p *ast.Package, interfaceName string) (met return nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName) } - methods, err = processInterface(fs, it, types, p.Name, imports) + methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports) if err != nil { return nil, nil, err } @@ -341,7 +342,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec { return result } -func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) { +func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) { if it.Methods == nil { return nil, nil } @@ -361,9 +362,9 @@ func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.Typ continue } case *ast.SelectorExpr: - embeddedMethods, err = processSelector(fs, v, imports) + embeddedMethods, err = processSelector(fs, currentPackage, v, imports) case *ast.Ident: - embeddedMethods, err = processIdent(fs, v, types, typesPrefix, imports) + embeddedMethods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports) } if err != nil { @@ -379,18 +380,18 @@ func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.Typ return methods, nil } -func processSelector(fs *token.FileSet, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) { +func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) { interfaceName := se.Sel.Name packageSelector := se.X.(*ast.Ident).Name - importPath, err := importPathByPrefix(imports, packageSelector) + importPath, err := findImportPathForName(packageSelector, imports, currentPackage) if err != nil { - return nil, errors.Wrapf(err, "unable to load %s.%s", packageSelector, interfaceName) + return nil, errors.Wrapf(err, "unable to find package %s", packageSelector) } - p, err := pkg.Load(importPath) - if err != nil { - return nil, errors.Wrap(err, "failed to load imported package") + p, ok := currentPackage.Imports[importPath] + if !ok { + return nil, fmt.Errorf("unable to find package %s", packageSelector) } astPkg, err := pkg.AST(fs, p) @@ -398,7 +399,7 @@ func processSelector(fs *token.FileSet, se *ast.SelectorExpr, imports []*ast.Imp return nil, errors.Wrap(err, "failed to import package") } - methods, _, err := findInterface(fs, astPkg, interfaceName) + methods, _, err := findInterface(fs, p, astPkg, interfaceName) return methods, err } @@ -431,7 +432,7 @@ func mergeMethods(ml1, ml2 methodsList) (methodsList, error) { var errEmbeddedInterfaceNotFound = errors.New("embedded interface not found") var errNotAnInterface = errors.New("embedded type is not an interface") -func processIdent(fs *token.FileSet, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) { +func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) { var embeddedInterface *ast.InterfaceType for _, t := range types { if t.Name.Name == i.Name { @@ -448,26 +449,25 @@ func processIdent(fs *token.FileSet, i *ast.Ident, types []*ast.TypeSpec, typesP return nil, errors.Wrap(errEmbeddedInterfaceNotFound, i.Name) } - return processInterface(fs, embeddedInterface, types, typesPrefix, imports) + return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports) } var errUnknownSelector = errors.New("unknown selector") -func importPathByPrefix(imports []*ast.ImportSpec, prefix string) (string, error) { +func findImportPathForName(name string, imports []*ast.ImportSpec, currentPackage *packages.Package) (string, error) { for _, i := range imports { - if i.Name != nil && i.Name.Name == prefix { + if i.Name != nil && i.Name.Name == name { return unquote(i.Path.Value), nil } } - for _, i := range imports { - p, err := pkg.Load(unquote(i.Path.Value)) - if err == nil && p.Name == prefix { - return p.PkgPath, nil + for path, pkg := range currentPackage.Imports { + if pkg.Name == name { + return path, nil } } - return "", errUnknownSelector + return "", errors.Wrapf(errUnknownSelector, name) } func unquote(s string) string { diff --git a/generator/generator_test.go b/generator/generator_test.go index 5695c8a0..d43bf386 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/tools/go/packages" ) func Test_unquote(t *testing.T) { @@ -51,77 +52,59 @@ func Test_unquote(t *testing.T) { } } -func Test_importPathByPrefix(t *testing.T) { +func Test_findImportPathForName(t *testing.T) { type args struct { + name string imports []*ast.ImportSpec - prefix string + cp *packages.Package } tests := []struct { name string args args - inspectErr func(*testing.T, error) - - want1 string - wantErr bool + want string + wantErr error }{ { - name: "prefix in import statement", - args: args{ - imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "myio"}, Path: &ast.BasicLit{Value: "io"}}}, - prefix: "myio", - }, - wantErr: false, - want1: "io", - }, - { - name: "failed to load package", + name: "path from import name", args: args{ - imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "myio"}, Path: &ast.BasicLit{Value: "unexisting_package"}}}, - prefix: "unexisting_package", + name: "pkg", + imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "pkg"}, Path: &ast.BasicLit{Value: "domain/pkgname"}}}, }, - wantErr: true, + want: "domain/pkgname", }, { - name: "success", + name: "path from package imports", args: args{ - imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}}, - prefix: "io", + name: "pkg", + cp: &packages.Package{ + Imports: map[string]*packages.Package{ + "domain/pkgname": { + Name: "pkg", + }, + }, + }, }, - want1: "io", - wantErr: false, + want: "domain/pkgname", }, { - name: "unknown prefix", + name: "not found", args: args{ - imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}}, - prefix: "unknown_prefix", + name: "pkg", + cp: &packages.Package{}, }, - inspectErr: func(t *testing.T, err error) { - assert.Equal(t, errUnknownSelector, err) - }, - wantErr: true, + wantErr: errors.Wrapf(errUnknownSelector, "pkg"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mc := minimock.NewController(t) - defer mc.Wait(time.Second) - - got1, err := importPathByPrefix(tt.args.imports, tt.args.prefix) - - assert.Equal(t, tt.want1, got1, "importPathByPrefix returned unexpected result") - - if tt.wantErr { - assert.Error(t, err) - if tt.inspectErr != nil { - tt.inspectErr(t, err) - } + path, err := findImportPathForName(tt.args.name, tt.args.imports, tt.args.cp) + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) } else { - assert.NoError(t, err) + assert.Equal(t, tt.want, path) } - }) } } @@ -179,7 +162,7 @@ func Test_processIdent(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, err := processIdent(tt.args.fs, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports) + got1, err := processIdent(tt.args.fs, nil, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports) assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result") @@ -266,6 +249,7 @@ func Test_mergeMethods(t *testing.T) { func Test_processSelector(t *testing.T) { type args struct { fs *token.FileSet + cp *packages.Package se *ast.SelectorExpr imports []*ast.ImportSpec } @@ -278,17 +262,19 @@ func Test_processSelector(t *testing.T) { inspectErr func(err error, t *testing.T) }{ { - name: "import not found", + name: "import with name not found", args: args{ se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknown"}, Sel: &ast.Ident{Name: "unknown"}}, + cp: &packages.Package{Imports: nil}, }, wantErr: true, }, { - name: "import failed", + name: "import not found", args: args{ se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknownpackage"}, Sel: &ast.Ident{Name: "Unknown"}}, - imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "unknownpackage"}, Path: &ast.BasicLit{Value: "unknown_path"}}}, + imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "unknown_path"}}}, + cp: &packages.Package{Imports: nil}, }, wantErr: true, }, @@ -298,6 +284,9 @@ func Test_processSelector(t *testing.T) { se: &ast.SelectorExpr{X: &ast.Ident{Name: "io"}, Sel: &ast.Ident{Name: "UnknownInterface"}}, imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}}, fs: token.NewFileSet(), + cp: &packages.Package{Imports: map[string]*packages.Package{ + "io": {}, + }}, }, wantErr: true, }, @@ -305,7 +294,7 @@ func Test_processSelector(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processSelector(tt.args.fs, tt.args.se, tt.args.imports) + got1, err := processSelector(tt.args.fs, tt.args.cp, tt.args.se, tt.args.imports) assert.Equal(t, tt.want1, got1, "processSelector returned unexpected result") @@ -324,6 +313,7 @@ func Test_processSelector(t *testing.T) { func Test_processInterface(t *testing.T) { type args struct { fs *token.FileSet + cp *packages.Package it *ast.InterfaceType types []*ast.TypeSpec typesPrefix string @@ -350,6 +340,7 @@ func Test_processInterface(t *testing.T) { name: "selector expression", args: args{ fs: token.NewFileSet(), + cp: &packages.Package{Imports: nil}, it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{ { Names: []*ast.Ident{{Name: "methodName"}}, @@ -390,7 +381,7 @@ func Test_processInterface(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got1, err := processInterface(tt.args.fs, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports) + got1, err := processInterface(tt.args.fs, tt.args.cp, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports) assert.Equal(t, tt.want1, got1, "processInterface returned unexpected result") @@ -462,7 +453,7 @@ func Test_findInterface(t *testing.T) { mc := minimock.NewController(t) defer mc.Wait(time.Second) - got1, _, err := findInterface(tt.args.fs, tt.args.p, tt.args.interfaceName) + got1, _, err := findInterface(tt.args.fs, nil, tt.args.p, tt.args.interfaceName) assert.Equal(t, tt.want1, got1, "findInterface returned unexpected result") diff --git a/pkg/package.go b/pkg/package.go index 09b94c01..b58765e3 100644 --- a/pkg/package.go +++ b/pkg/package.go @@ -14,7 +14,7 @@ var errPackageNotFound = errors.New("package not found") // Load loads package by its import path func Load(path string) (*packages.Package, error) { - cfg := &packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles} + cfg := &packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps} pkgs, err := packages.Load(cfg, path) if err != nil { return nil, err