From 313d65676557996c64daa399593ed980e6db952f Mon Sep 17 00:00:00 2001 From: LandonTClipp Date: Fri, 20 Oct 2023 12:54:21 -0500 Subject: [PATCH] updates --- cmd/mockery.go | 11 ++--- pkg/interface.go | 34 +++++++++++++++ pkg/method.go | 88 +++++++++++++++++++++++++++++++++++++++ pkg/outputter.go | 2 - pkg/parse.go | 31 -------------- pkg/registry/package.go | 24 +++++------ pkg/registry/registry.go | 32 +++----------- pkg/template_generator.go | 16 +++++-- 8 files changed, 156 insertions(+), 82 deletions(-) create mode 100644 pkg/interface.go create mode 100644 pkg/method.go diff --git a/cmd/mockery.go b/cmd/mockery.go index 922a16a89..64e98721f 100644 --- a/cmd/mockery.go +++ b/cmd/mockery.go @@ -211,10 +211,6 @@ func (r *RootApp) Run() error { return nil } - var osp pkg.OutputStreamProvider - if r.Config.Print { - osp = &pkg.StdoutStreamProvider{} - } buildTags := strings.Split(r.Config.BuildTags, " ") var boilerplate string @@ -268,7 +264,7 @@ func (r *RootApp) Run() error { } ifaceLog.Debug().Msg("config specifies to generate this interface") - outputter := pkg.NewOutputter(&r.Config, boilerplate, true) + outputter := pkg.NewOutputter(&r.Config, boilerplate) if err := outputter.Generate(ifaceCtx, iface); err != nil { return err } @@ -277,6 +273,11 @@ func (r *RootApp) Run() error { return nil } + var osp pkg.OutputStreamProvider + if r.Config.Print { + osp = &pkg.StdoutStreamProvider{} + } + if r.Config.Name != "" && r.Config.All { log.Fatal().Msgf("Specify --name or --all, but not both") } else if (r.Config.FileName != "" || r.Config.StructName != "") && r.Config.All { diff --git a/pkg/interface.go b/pkg/interface.go new file mode 100644 index 000000000..641681071 --- /dev/null +++ b/pkg/interface.go @@ -0,0 +1,34 @@ +package pkg + +import ( + "go/ast" + "go/types" +) + +// Interface type represents the target type that we will generate a mock for. +// It could be an interface, or a function type. +// Function type emulates: an interface it has 1 method with the function signature +// and a general name, e.g. "Execute". +type Interface struct { + Name string // Name of the type to be mocked. + QualifiedName string // Path to the package of the target type. + FileName string + File *ast.File + Pkg TypesPackage + NamedType *types.Named + IsFunction bool // If true, this instance represents a function, otherwise it's an interface. + ActualInterface *types.Interface // Holds the actual interface type, in case it's an interface. + SingleFunction *Method // Holds the function type information, in case it's a function type. +} + +func (iface *Interface) Methods() []*Method { + if iface.IsFunction { + return []*Method{iface.SingleFunction} + } + methods := make([]*Method, iface.ActualInterface.NumMethods()) + for i := 0; i < iface.ActualInterface.NumMethods(); i++ { + fn := iface.ActualInterface.Method(i) + methods[i] = &Method{Name: fn.Name(), Signature: fn.Type().(*types.Signature)} + } + return methods +} diff --git a/pkg/method.go b/pkg/method.go new file mode 100644 index 000000000..07aedc8d6 --- /dev/null +++ b/pkg/method.go @@ -0,0 +1,88 @@ +package pkg + +import ( + "go/types" + "path" + "strings" +) + +type Method struct { + Name string + Signature *types.Signature +} + +type Imports map[string]*types.Package + +func (m Method) populateImports(imports Imports) { + for i := 0; i < m.Signature.Params().Len(); i++ { + m.importsHelper(m.Signature.Params().At(i).Type(), imports) + } +} + +// stripVendorPath strips the vendor dir prefix from a package path. +// For example we might encounter an absolute path like +// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved +// to github.com/pkg/errors. +func stripVendorPath(p string) string { + parts := strings.Split(p, "/vendor/") + if len(parts) == 1 { + return p + } + return strings.TrimLeft(path.Join(parts[1:]...), "/") +} + +// importsHelper extracts all the package imports for a given type +// recursively. The imported packages by a single type can be more than +// one (ex: map[a.Type]b.Type). +func (m Method) importsHelper(elem types.Type, imports map[string]*types.Package) { + switch t := elem.(type) { + case *types.Named: + if pkg := t.Obj().Pkg(); pkg != nil { + imports[stripVendorPath(pkg.Path())] = pkg + } + // The imports of a Type with a TypeList must be added to the imports list + // For example: Foo[otherpackage.Bar] , must have otherpackage imported + if targs := t.TypeArgs(); targs != nil { + for i := 0; i < targs.Len(); i++ { + m.importsHelper(targs.At(i), imports) + } + } + + case *types.Array: + m.importsHelper(t.Elem(), imports) + + case *types.Slice: + m.importsHelper(t.Elem(), imports) + + case *types.Signature: + for i := 0; i < t.Params().Len(); i++ { + m.importsHelper(t.Params().At(i).Type(), imports) + } + for i := 0; i < t.Results().Len(); i++ { + m.importsHelper(t.Results().At(i).Type(), imports) + } + + case *types.Map: + m.importsHelper(t.Key(), imports) + m.importsHelper(t.Elem(), imports) + + case *types.Chan: + m.importsHelper(t.Elem(), imports) + + case *types.Pointer: + m.importsHelper(t.Elem(), imports) + + case *types.Struct: // anonymous struct + for i := 0; i < t.NumFields(); i++ { + m.importsHelper(t.Field(i).Type(), imports) + } + + case *types.Interface: // anonymous interface + for i := 0; i < t.NumExplicitMethods(); i++ { + m.importsHelper(t.ExplicitMethod(i).Type(), imports) + } + for i := 0; i < t.NumEmbeddeds(); i++ { + m.importsHelper(t.EmbeddedType(i), imports) + } + } +} diff --git a/pkg/outputter.go b/pkg/outputter.go index 9860513f7..82cf76d17 100644 --- a/pkg/outputter.go +++ b/pkg/outputter.go @@ -297,12 +297,10 @@ type Outputter struct { func NewOutputter( config *config.Config, boilerplate string, - dryRun bool, ) *Outputter { return &Outputter{ boilerplate: boilerplate, config: config, - dryRun: dryRun, } } diff --git a/pkg/parse.go b/pkg/parse.go index e85077a96..c3836feb3 100644 --- a/pkg/parse.go +++ b/pkg/parse.go @@ -261,43 +261,12 @@ func (p *Parser) packageInterfaces( return ifaces } -type Method struct { - Name string - Signature *types.Signature -} - type TypesPackage interface { Name() string Path() string } -// Interface type represents the target type that we will generate a mock for. -// It could be an interface, or a function type. -// Function type emulates: an interface it has 1 method with the function signature -// and a general name, e.g. "Execute". -type Interface struct { - Name string // Name of the type to be mocked. - QualifiedName string // Path to the package of the target type. - FileName string - File *ast.File - Pkg TypesPackage - NamedType *types.Named - IsFunction bool // If true, this instance represents a function, otherwise it's an interface. - ActualInterface *types.Interface // Holds the actual interface type, in case it's an interface. - SingleFunction *Method // Holds the function type information, in case it's a function type. -} -func (iface *Interface) Methods() []*Method { - if iface.IsFunction { - return []*Method{iface.SingleFunction} - } - methods := make([]*Method, iface.ActualInterface.NumMethods()) - for i := 0; i < iface.ActualInterface.NumMethods(); i++ { - fn := iface.ActualInterface.Method(i) - methods[i] = &Method{Name: fn.Name(), Signature: fn.Type().(*types.Signature)} - } - return methods -} type sortableIFaceList []*Interface diff --git a/pkg/registry/package.go b/pkg/registry/package.go index 376824242..d4945ea68 100644 --- a/pkg/registry/package.go +++ b/pkg/registry/package.go @@ -1,20 +1,26 @@ package registry import ( - "go/types" "path" "strings" ) +type TypesPackage interface { + Name() string + Path() string +} + // Package represents an imported package. type Package struct { - pkg *types.Package + pkg TypesPackage Alias string } // NewPackage creates a new instance of Package. -func NewPackage(pkg *types.Package) *Package { return &Package{pkg: pkg} } +func NewPackage(pkg TypesPackage) *Package { + return &Package{pkg: pkg} +} // Qualifier returns the qualifier which must be used to refer to types // declared in the package. @@ -66,17 +72,7 @@ func (p Package) uniqueName(lvl int) string { return name } -// stripVendorPath strips the vendor dir prefix from a package path. -// For example we might encounter an absolute path like -// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved -// to github.com/pkg/errors. -func stripVendorPath(p string) string { - parts := strings.Split(p, "/vendor/") - if len(parts) == 1 { - return p - } - return strings.TrimLeft(path.Join(parts[1:]...), "/") -} + func min(a, b int) int { if a < b { diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index a237cdce9..75a70c04e 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -5,10 +5,10 @@ import ( "fmt" "go/ast" "go/types" - "path/filepath" "sort" "strings" + "github.com/chigopher/pathlib" "golang.org/x/tools/go/packages" ) @@ -19,25 +19,18 @@ import ( type Registry struct { srcPkgName string srcPkgTypes *types.Package - moqPkgPath string + outputPath *pathlib.Path aliases map[string]string imports map[string]*Package } // New loads the source package info and returns a new instance of // Registry. -func New(srcDir, moqPkg string) (*Registry, error) { - srcPkg, err := pkgInfoFromPath( - srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes, - ) - if err != nil { - return nil, fmt.Errorf("couldn't load source package: %s", err) - } - +func New(srcPkg *packages.Package, outputPath *pathlib.Path) (*Registry, error) { return &Registry{ srcPkgName: srcPkg.Name, srcPkgTypes: srcPkg.Types, - moqPkgPath: findPkgPath(moqPkg, srcPkg.PkgPath), + outputPath: outputPath, aliases: parseImportsAliases(srcPkg.Syntax), imports: make(map[string]*Package), }, nil @@ -78,7 +71,6 @@ func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypePar func (r *Registry) MethodScope() *MethodScope { return &MethodScope{ registry: r, - moqPkgPath: r.moqPkgPath, conflicted: map[string]bool{}, } } @@ -88,7 +80,7 @@ func (r *Registry) MethodScope() *MethodScope { // packages. func (r *Registry) AddImport(pkg *types.Package) *Package { path := stripVendorPath(pkg.Path()) - if path == r.moqPkgPath { + if pathlib.NewPath(path).Equals(r.outputPath) { return nil } @@ -176,20 +168,6 @@ func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, return pkgs[0], nil } -func findPkgPath(pkgInputVal string, srcPkgPath string) string { - if pkgInputVal == "" { - return srcPkgPath - } - if pkgInDir(srcPkgPath, pkgInputVal) { - return srcPkgPath - } - subdirectoryPath := filepath.Join(srcPkgPath, pkgInputVal) - if pkgInDir(subdirectoryPath, pkgInputVal) { - return subdirectoryPath - } - return "" -} - func pkgInDir(pkgName, dir string) bool { currentPkg, err := pkgInfoFromPath(dir, packages.NeedName) if err != nil { diff --git a/pkg/template_generator.go b/pkg/template_generator.go index 76d75fdf5..27da85689 100644 --- a/pkg/template_generator.go +++ b/pkg/template_generator.go @@ -1,7 +1,7 @@ package pkg import ( - "github.com/vektra/mockery/v2/pkg/registry" + "github.com/vektra/mockery/v2/pkg/config" "github.com/vektra/mockery/v2/pkg/template" ) @@ -18,12 +18,22 @@ func NewTemplateGenerator(config TemplateGeneratorConfig) *TemplateGenerator { } } -func (g *TemplateGenerator) Generate() error { +func (g *TemplateGenerator) Generate(iface *Interface, ifaceConfig *config.Config) error { templ, err := template.New(g.config.Style) if err != nil { return err } - data := registry. + imports := Imports{} + for _, method := range iface.Methods() { + method.populateImports(imports) + } + // TODO: Work on getting these imports into the template + + data := template.Data{ + PkgName: ifaceConfig.Outpkg, + SrcPkgQualifier: iface.Pkg.Name() + ".", + Imports: + } return nil }