Skip to content

Commit

Permalink
Load source package with imports
Browse files Browse the repository at this point in the history
This way the generator won't have to load the same packages multiple
times while processing interfaces.
  • Loading branch information
rgngl committed Jan 29, 2022
1 parent e8ff957 commit ea7b21a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 75 deletions.
44 changes: 22 additions & 22 deletions generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package generator

import (
"bytes"
"fmt"
"path/filepath"
"sort"
"strings"
Expand Down Expand Up @@ -184,7 +185,7 @@ func NewGenerator(options Options) (*Generator, error) {
options.Imports = append(options.Imports, `"`+srcPackage.PkgPath+`"`)
}

methods, imports, err := findInterface(fs, srcPackageAST, options.InterfaceName)
methods, imports, err := findInterface(fs, srcPackage, srcPackageAST, options.InterfaceName)
if err != nil {
return nil, errors.Wrap(err, "failed to parse interface declaration")
}
Expand Down Expand Up @@ -291,7 +292,7 @@ var errInterfaceNotFound = errors.New("interface type declaration not found")
// findInterface looks for the interface declaration in the given directory
// and returns a list of the interface's methods and a list of imports from the file
// where interface type declaration was found
func findInterface(fs *token.FileSet, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) {
func findInterface(fs *token.FileSet, currentPackage *packages.Package, p *ast.Package, interfaceName string) (methods methodsList, imports []*ast.ImportSpec, err error) {
var found bool
var types []*ast.TypeSpec
var it *ast.InterfaceType
Expand All @@ -317,7 +318,7 @@ func findInterface(fs *token.FileSet, p *ast.Package, interfaceName string) (met
return nil, nil, errors.Wrap(errInterfaceNotFound, interfaceName)
}

methods, err = processInterface(fs, it, types, p.Name, imports)
methods, err = processInterface(fs, currentPackage, it, types, p.Name, imports)
if err != nil {
return nil, nil, err
}
Expand All @@ -341,7 +342,7 @@ func typeSpecs(f *ast.File) []*ast.TypeSpec {
return result
}

func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) {
func processInterface(fs *token.FileSet, currentPackage *packages.Package, it *ast.InterfaceType, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methods methodsList, err error) {
if it.Methods == nil {
return nil, nil
}
Expand All @@ -361,9 +362,9 @@ func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.Typ
continue
}
case *ast.SelectorExpr:
embeddedMethods, err = processSelector(fs, v, imports)
embeddedMethods, err = processSelector(fs, currentPackage, v, imports)
case *ast.Ident:
embeddedMethods, err = processIdent(fs, v, types, typesPrefix, imports)
embeddedMethods, err = processIdent(fs, currentPackage, v, types, typesPrefix, imports)
}

if err != nil {
Expand All @@ -379,26 +380,26 @@ func processInterface(fs *token.FileSet, it *ast.InterfaceType, types []*ast.Typ
return methods, nil
}

func processSelector(fs *token.FileSet, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) {
func processSelector(fs *token.FileSet, currentPackage *packages.Package, se *ast.SelectorExpr, imports []*ast.ImportSpec) (methodsList, error) {
interfaceName := se.Sel.Name
packageSelector := se.X.(*ast.Ident).Name

importPath, err := importPathByPrefix(imports, packageSelector)
importPath, err := findImportPathForName(packageSelector, imports, currentPackage)
if err != nil {
return nil, errors.Wrapf(err, "unable to load %s.%s", packageSelector, interfaceName)
return nil, errors.Wrapf(err, "unable to find package %s", packageSelector)
}

p, err := pkg.Load(importPath)
if err != nil {
return nil, errors.Wrap(err, "failed to load imported package")
p, ok := currentPackage.Imports[importPath]
if !ok {
return nil, fmt.Errorf("unable to find package %s", packageSelector)
}

astPkg, err := pkg.AST(fs, p)
if err != nil {
return nil, errors.Wrap(err, "failed to import package")
}

methods, _, err := findInterface(fs, astPkg, interfaceName)
methods, _, err := findInterface(fs, p, astPkg, interfaceName)

return methods, err
}
Expand Down Expand Up @@ -431,7 +432,7 @@ func mergeMethods(ml1, ml2 methodsList) (methodsList, error) {
var errEmbeddedInterfaceNotFound = errors.New("embedded interface not found")
var errNotAnInterface = errors.New("embedded type is not an interface")

func processIdent(fs *token.FileSet, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) {
func processIdent(fs *token.FileSet, currentPackage *packages.Package, i *ast.Ident, types []*ast.TypeSpec, typesPrefix string, imports []*ast.ImportSpec) (methodsList, error) {
var embeddedInterface *ast.InterfaceType
for _, t := range types {
if t.Name.Name == i.Name {
Expand All @@ -448,26 +449,25 @@ func processIdent(fs *token.FileSet, i *ast.Ident, types []*ast.TypeSpec, typesP
return nil, errors.Wrap(errEmbeddedInterfaceNotFound, i.Name)
}

return processInterface(fs, embeddedInterface, types, typesPrefix, imports)
return processInterface(fs, currentPackage, embeddedInterface, types, typesPrefix, imports)
}

var errUnknownSelector = errors.New("unknown selector")

func importPathByPrefix(imports []*ast.ImportSpec, prefix string) (string, error) {
func findImportPathForName(name string, imports []*ast.ImportSpec, currentPackage *packages.Package) (string, error) {
for _, i := range imports {
if i.Name != nil && i.Name.Name == prefix {
if i.Name != nil && i.Name.Name == name {
return unquote(i.Path.Value), nil
}
}

for _, i := range imports {
p, err := pkg.Load(unquote(i.Path.Value))
if err == nil && p.Name == prefix {
return p.PkgPath, nil
for path, pkg := range currentPackage.Imports {
if pkg.Name == name {
return path, nil
}
}

return "", errUnknownSelector
return "", errors.Wrapf(errUnknownSelector, name)
}

func unquote(s string) string {
Expand Down
95 changes: 43 additions & 52 deletions generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func Test_unquote(t *testing.T) {
Expand Down Expand Up @@ -51,77 +52,59 @@ func Test_unquote(t *testing.T) {
}
}

func Test_importPathByPrefix(t *testing.T) {
func Test_findImportPathForName(t *testing.T) {
type args struct {
name string
imports []*ast.ImportSpec
prefix string
cp *packages.Package
}
tests := []struct {
name string
args args

inspectErr func(*testing.T, error)

want1 string
wantErr bool
want string
wantErr error
}{
{
name: "prefix in import statement",
args: args{
imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "myio"}, Path: &ast.BasicLit{Value: "io"}}},
prefix: "myio",
},
wantErr: false,
want1: "io",
},
{
name: "failed to load package",
name: "path from import name",
args: args{
imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "myio"}, Path: &ast.BasicLit{Value: "unexisting_package"}}},
prefix: "unexisting_package",
name: "pkg",
imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "pkg"}, Path: &ast.BasicLit{Value: "domain/pkgname"}}},
},
wantErr: true,
want: "domain/pkgname",
},
{
name: "success",
name: "path from package imports",
args: args{
imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}},
prefix: "io",
name: "pkg",
cp: &packages.Package{
Imports: map[string]*packages.Package{
"domain/pkgname": {
Name: "pkg",
},
},
},
},
want1: "io",
wantErr: false,
want: "domain/pkgname",
},
{
name: "unknown prefix",
name: "not found",
args: args{
imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}},
prefix: "unknown_prefix",
name: "pkg",
cp: &packages.Package{},
},
inspectErr: func(t *testing.T, err error) {
assert.Equal(t, errUnknownSelector, err)
},
wantErr: true,
wantErr: errors.Wrapf(errUnknownSelector, "pkg"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mc := minimock.NewController(t)
defer mc.Wait(time.Second)

got1, err := importPathByPrefix(tt.args.imports, tt.args.prefix)

assert.Equal(t, tt.want1, got1, "importPathByPrefix returned unexpected result")

if tt.wantErr {
assert.Error(t, err)
if tt.inspectErr != nil {
tt.inspectErr(t, err)
}
path, err := findImportPathForName(tt.args.name, tt.args.imports, tt.args.cp)
if tt.wantErr != nil {
assert.EqualError(t, err, tt.wantErr.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, path)
}

})
}
}
Expand Down Expand Up @@ -179,7 +162,7 @@ func Test_processIdent(t *testing.T) {
mc := minimock.NewController(t)
defer mc.Wait(time.Second)

got1, err := processIdent(tt.args.fs, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports)
got1, err := processIdent(tt.args.fs, nil, tt.args.i, tt.args.types, tt.args.typesPrefix, tt.args.imports)

assert.Equal(t, tt.want1, got1, "processIdent returned unexpected result")

Expand Down Expand Up @@ -266,6 +249,7 @@ func Test_mergeMethods(t *testing.T) {
func Test_processSelector(t *testing.T) {
type args struct {
fs *token.FileSet
cp *packages.Package
se *ast.SelectorExpr
imports []*ast.ImportSpec
}
Expand All @@ -278,17 +262,19 @@ func Test_processSelector(t *testing.T) {
inspectErr func(err error, t *testing.T)
}{
{
name: "import not found",
name: "import with name not found",
args: args{
se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknown"}, Sel: &ast.Ident{Name: "unknown"}},
cp: &packages.Package{Imports: nil},
},
wantErr: true,
},
{
name: "import failed",
name: "import not found",
args: args{
se: &ast.SelectorExpr{X: &ast.Ident{Name: "unknownpackage"}, Sel: &ast.Ident{Name: "Unknown"}},
imports: []*ast.ImportSpec{{Name: &ast.Ident{Name: "unknownpackage"}, Path: &ast.BasicLit{Value: "unknown_path"}}},
imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "unknown_path"}}},
cp: &packages.Package{Imports: nil},
},
wantErr: true,
},
Expand All @@ -298,14 +284,17 @@ func Test_processSelector(t *testing.T) {
se: &ast.SelectorExpr{X: &ast.Ident{Name: "io"}, Sel: &ast.Ident{Name: "UnknownInterface"}},
imports: []*ast.ImportSpec{{Path: &ast.BasicLit{Value: "io"}}},
fs: token.NewFileSet(),
cp: &packages.Package{Imports: map[string]*packages.Package{
"io": {},
}},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got1, err := processSelector(tt.args.fs, tt.args.se, tt.args.imports)
got1, err := processSelector(tt.args.fs, tt.args.cp, tt.args.se, tt.args.imports)

assert.Equal(t, tt.want1, got1, "processSelector returned unexpected result")

Expand All @@ -324,6 +313,7 @@ func Test_processSelector(t *testing.T) {
func Test_processInterface(t *testing.T) {
type args struct {
fs *token.FileSet
cp *packages.Package
it *ast.InterfaceType
types []*ast.TypeSpec
typesPrefix string
Expand All @@ -350,6 +340,7 @@ func Test_processInterface(t *testing.T) {
name: "selector expression",
args: args{
fs: token.NewFileSet(),
cp: &packages.Package{Imports: nil},
it: &ast.InterfaceType{Methods: &ast.FieldList{List: []*ast.Field{
{
Names: []*ast.Ident{{Name: "methodName"}},
Expand Down Expand Up @@ -390,7 +381,7 @@ func Test_processInterface(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got1, err := processInterface(tt.args.fs, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports)
got1, err := processInterface(tt.args.fs, tt.args.cp, tt.args.it, tt.args.types, tt.args.typesPrefix, tt.args.imports)

assert.Equal(t, tt.want1, got1, "processInterface returned unexpected result")

Expand Down Expand Up @@ -462,7 +453,7 @@ func Test_findInterface(t *testing.T) {
mc := minimock.NewController(t)
defer mc.Wait(time.Second)

got1, _, err := findInterface(tt.args.fs, tt.args.p, tt.args.interfaceName)
got1, _, err := findInterface(tt.args.fs, nil, tt.args.p, tt.args.interfaceName)

assert.Equal(t, tt.want1, got1, "findInterface returned unexpected result")

Expand Down
2 changes: 1 addition & 1 deletion pkg/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ var errPackageNotFound = errors.New("package not found")

// Load loads package by its import path
func Load(path string) (*packages.Package, error) {
cfg := &packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles}
cfg := &packages.Config{Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps}
pkgs, err := packages.Load(cfg, path)
if err != nil {
return nil, err
Expand Down

0 comments on commit ea7b21a

Please sign in to comment.